pax_global_header 0000666 0000000 0000000 00000000064 14746732372 0014531 g ustar 00root root 0000000 0000000 52 comment=3e79613493826e1eab59cfcfa2e4a54b56900c77
tree-0.1.9/ 0000775 0000000 0000000 00000000000 14746732372 0012477 5 ustar 00root root 0000000 0000000 tree-0.1.9/.github/ 0000775 0000000 0000000 00000000000 14746732372 0014037 5 ustar 00root root 0000000 0000000 tree-0.1.9/.github/workflows/ 0000775 0000000 0000000 00000000000 14746732372 0016074 5 ustar 00root root 0000000 0000000 tree-0.1.9/.github/workflows/build.yml 0000664 0000000 0000000 00000005370 14746732372 0017723 0 ustar 00root root 0000000 0000000 name: build
on:
  push:
    branches: [master]
  pull_request:
    branches: [master]
  release:
    types: [created]
  workflow_dispatch:

jobs:
  # Source distribution; uploaded only for manual runs and new releases.
  sdist:
    name: sdist
    runs-on: ubuntu-24.04
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b  # v5.3.0
        with:
          python-version: '3.11'
      - name: Create sdist
        run: |
          python -m pip install --upgrade pip setuptools
          python setup.py sdist
        shell: bash
      - name: List output directory
        run: ls -lh dist/dm_tree*.tar.gz
        shell: bash
      - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08  # v4.6.0
        if: ${{ github.event_name == 'workflow_dispatch' || (github.event_name == 'release' && github.event.action == 'created') }}
        with:
          name: dm-tree-sdist
          path: dist/dm_tree*.tar.gz

  # Binary wheels for every OS in the matrix, built and tested by cibuildwheel.
  bdist-wheel:
    name: Build wheels on ${{ matrix.os }}
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-24.04, macos-14, windows-2022]  # latest
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
        with:
          submodules: true
      - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b  # v5.3.0
        with:
          python-version: "3.11"
      # QEMU lets the x86_64 Linux runner build the aarch64 wheels requested
      # in CIBW_ARCHS_LINUX below.
      - name: Set up QEMU
        if: runner.os == 'Linux'
        uses: docker/setup-qemu-action@53851d14592bedcffcf25ea515637cff71ef929a  # v3.3.0
        with:
          platforms: all
          # This should be temporary
          # xref https://github.com/docker/setup-qemu-action/issues/188
          # xref https://github.com/tonistiigi/binfmt/issues/215
          image: tonistiigi/binfmt:qemu-v8.1.5
      - name: Install cibuildwheel
        run: python -m pip install cibuildwheel==2.22.0
      - name: Build wheels
        run: python -m cibuildwheel --output-dir wheelhouse
        env:
          CIBW_ARCHS_LINUX: auto aarch64
          CIBW_ARCHS_MACOS: universal2
          CIBW_BUILD: "cp310-* cp311-* cp312-* cp313-* cp313t-*"
          CIBW_BUILD_VERBOSITY: 1
          # Parallelise `cmake --build` (invoked from setup.py).
          # Values in this `env:` map are never shell-expanded, so the old
          # MAKEFLAGS="-j$(nproc)" reached make literally. Host variables are
          # also not forwarded into the Linux build container. CIBW_ENVIRONMENT
          # is applied on every platform, inside the container included, and
          # CMAKE_BUILD_PARALLEL_LEVEL is honoured by every CMake generator.
          CIBW_ENVIRONMENT: CMAKE_BUILD_PARALLEL_LEVEL=4
          CIBW_FREE_THREADED_SUPPORT: true
          CIBW_PRERELEASE_PYTHONS: true
          CIBW_SKIP: "*musllinux* *i686* *win32* *t-win*"
          CIBW_TEST_COMMAND: pytest --pyargs tree
          CIBW_TEST_REQUIRES: pytest
      - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08  # v4.6.0
        if: ${{ github.event_name == 'workflow_dispatch' || (github.event_name == 'release' && github.event.action == 'created') }}
        with:
          name: dm-tree-bdist-wheel-${{ matrix.os }}-${{ strategy.job-index }}
          path: wheelhouse/*.whl
tree-0.1.9/CONTRIBUTING.md 0000664 0000000 0000000 00000002115 14746732372 0014727 0 ustar 00root root 0000000 0000000 # How to Contribute
We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
## Community Guidelines
This project follows
[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
tree-0.1.9/LICENSE 0000664 0000000 0000000 00000026136 14746732372 0013514 0 ustar 00root root 0000000 0000000
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
tree-0.1.9/MANIFEST.in 0000664 0000000 0000000 00000000342 14746732372 0014234 0 ustar 00root root 0000000 0000000 # metadata
include LICENSE
include WORKSPACE
include CONTRIBUTING.md
include README.md
# python package requirements
include requirements*.txt
# tree files
recursive-include . CMakeLists.txt *.cc *.cpp *.h *.sh *.py *.cmake
tree-0.1.9/README.md 0000664 0000000 0000000 00000001736 14746732372 0013765 0 ustar 00root root 0000000 0000000 # Tree
`tree` is a library for working with nested data structures. In a way, `tree`
generalizes the builtin `map` function which only supports flat sequences,
and allows applying a function to each "leaf" while preserving the overall
structure.
```python
>>> import tree
>>> structure = [[1], [[[2, 3]]], [4]]
>>> tree.flatten(structure)
[1, 2, 3, 4]
>>> tree.map_structure(lambda v: v**2, structure)
[[1], [[[4, 9]]], [16]]
```
`tree` is backed by an optimized C++ implementation suitable for use in
demanding applications, such as machine learning models.
## Installation
From PyPI:
```shell
$ pip install dm-tree
```
Directly from GitHub using pip:
```shell
$ pip install git+https://github.com/deepmind/tree.git
```
Build from source:
```shell
$ python setup.py install
```
## Support
If you are having issues, please let us know by filing an issue on our
[issue tracker](https://github.com/deepmind/tree/issues).
## License
The project is licensed under the Apache 2.0 license.
tree-0.1.9/docs/ 0000775 0000000 0000000 00000000000 14746732372 0013427 5 ustar 00root root 0000000 0000000 tree-0.1.9/docs/Makefile 0000664 0000000 0000000 00000001105 14746732372 0015064 0 ustar 00root root 0000000 0000000 # Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, for example:
#   make html SPHINXOPTS="-W"
SPHINXOPTS =
SPHINXBUILD = sphinx-build
# The Sphinx sources (conf.py and the *.rst files) sit next to this Makefile.
SOURCEDIR = .
# Sphinx "make mode" writes each builder's output to $(BUILDDIR)/<builder>,
# e.g. _build/html for "make html".
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
tree-0.1.9/docs/api.rst 0000664 0000000 0000000 00000002361 14746732372 0014734 0 ustar 00root root 0000000 0000000 #############
API Reference
#############
All ``tree`` functions operate on nested tree-like structures. A *structure*
is recursively defined as::
Structure = Union[
Any,
Sequence['Structure'],
Mapping[Any, 'Structure'],
'AnyNamedTuple',
]
.. TODO(slebedev): Support @dataclass classes if we make @attr.s
.. support public.
A single (non-nested) Python object is a perfectly valid structure::
>>> tree.map_structure(lambda v: v * 2, 42)
84
>>> tree.flatten(42)
[42]
You can check whether a structure is actually nested via
:func:`~tree.is_nested`::
>>> tree.is_nested(42)
False
>>> tree.is_nested([42])
True
Note that ``tree`` only supports acyclic structures. The behavior for
structures with cyclic references is undefined.
.. currentmodule:: tree
.. autofunction:: is_nested
.. autofunction:: assert_same_structure
.. autofunction:: unflatten_as
.. autofunction:: flatten
.. autofunction:: flatten_up_to
.. autofunction:: flatten_with_path
.. autofunction:: flatten_with_path_up_to
.. autofunction:: map_structure
.. autofunction:: map_structure_up_to
.. autofunction:: map_structure_with_path
.. autofunction:: map_structure_with_path_up_to
.. autofunction:: traverse
.. autodata:: MAP_TO_NONE
tree-0.1.9/docs/changes.rst 0000664 0000000 0000000 00000002454 14746732372 0015576 0 ustar 00root root 0000000 0000000 #########
Changelog
#########
Version 0.1.9
=============
Released 2025-01-30
* Dropped support for Python <3.10.
Version 0.1.8
=============
Released 2022-12-19
* Bumped pybind11 to v2.10.1 to support Python 3.11.
* Dropped support for Python 3.6.
Version 0.1.7
=============
Released 2022-04-10
* The build is now done via CMake instead of Bazel.
Version 0.1.6
=============
Released 2021-04-12
* Dropped support for Python 2.X.
* Added a generalization of ``tree.traverse`` which keeps track of the
current path during traversal.
Version 0.1.5
=============
Released 2020-04-30
* Added a new function ``tree.traverse`` which allows traversing a nested
  structure and applying a function to each subtree.
Version 0.1.4
=============
Released 2020-03-27
* Added support for ``types.MappingProxyType`` on Python 3.X.
Version 0.1.3
=============
Released 2020-01-30
* Fixed ``ImportError`` when ``wrapt`` was not available.
Version 0.1.2
=============
Released 2020-01-29
* Added support for ``wrapt.ObjectWrapper`` objects.
* Added ``StructureKV[K, V]`` and ``Structure = StructureKV[Text, V]`` types.
Version 0.1.1
=============
Released 2019-11-07
* Ensured that the produced Linux wheels are manylinux2010-compatible.
Version 0.1.0
=============
Released 2019-11-05
* Initial public release.
tree-0.1.9/docs/conf.py 0000664 0000000 0000000 00000007642 14746732372 0014737 0 ustar 00root root 0000000 0000000 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Configuration file for the Sphinx documentation builder."""
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# http://www.sphinx-doc.org/en/master/config
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
# pylint: disable=g-bad-import-order
# pylint: disable=g-import-not-at-top
import datetime
import inspect
import os
import sys
sys.path.insert(0, os.path.abspath('../'))
import tree
# -- Project information -----------------------------------------------------

project = 'Tree'
author = 'DeepMind'
# Stamp the year the docs are built so the copyright footer never goes stale.
copyright = '{}, DeepMind'.format(  # pylint: disable=redefined-builtin
    datetime.date.today().year)

# -- General configuration ---------------------------------------------------

# Root document of the toctree (index.rst in this directory).
master_doc = 'index'

# Sphinx extensions, all bundled with Sphinx itself:
#   autodoc / autosummary -- pull API docs from the ``tree`` docstrings;
#   linkcode              -- "[source]" links built by linkcode_resolve below;
#   napoleon              -- understands Google/NumPy-style docstring sections;
#   doctest               -- runs the ``>>>`` examples in the docs.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'sphinx.ext.linkcode',
    'sphinx.ext.napoleon',
    'sphinx.ext.doctest',
]

# Extra template directories, relative to this directory.
templates_path = ['_templates']

# Files and directories (relative to the source dir) Sphinx should not read.
# These patterns also apply to html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# -- Options for autodoc -----------------------------------------------------

# Keep members in source order and document dunder methods, except for the
# uninteresting ones listed under 'exclude-members'.
autodoc_default_options = {
    'member-order': 'bysource',
    'special-members': True,
    'exclude-members': '__repr__, __str__, __weakref__',
}

# -- Options for HTML output -------------------------------------------------

# Read the Docs theme; installed via docs/requirements.txt.
html_theme = 'sphinx_rtd_theme'
# Theme knobs left at their defaults; uncomment to override.
html_theme_options = {
    # 'collapse_navigation': False,
    # 'sticky_navigation': False,
}

# -- Options for doctest -----------------------------------------------------

# Code run as setup for every tested file, so doctest examples can use these
# names without importing them.
doctest_global_setup = '''
import collections
import numpy as np
import tree
'''
# -- Source code links -------------------------------------------------------
def linkcode_resolve(domain, info):
  """Resolve a GitHub URL corresponding to a Python object.

  Hook for ``sphinx.ext.linkcode``.

  Args:
    domain: the Sphinx domain of the object; only ``'py'`` is handled.
    info: mapping with ``'module'`` (a module name) and ``'fullname'`` (a
      dotted attribute path inside that module).

  Returns:
    A URL of the form ``.../tree/<file>#L<first>-L<last>`` covering the
    object's source lines, or ``None`` when the source cannot be located.
  """
  if domain != 'py':
    return None

  # ``sys.modules`` is a dict: a module that was never imported surfaces as a
  # missing key, not as an ImportError, so look it up without raising.
  mod = sys.modules.get(info['module'])
  if mod is None:
    return None

  obj = mod
  try:
    for attr in info['fullname'].split('.'):
      obj = getattr(obj, attr)
  except AttributeError:
    return None
  obj = inspect.unwrap(obj)

  try:
    filename = inspect.getsourcefile(obj)
    source, lineno = inspect.getsourcelines(obj)
  except (OSError, TypeError):
    # TypeError: builtins / extension objects / plain data have no source
    # file; OSError: the source file exists but could not be read.
    return None
  if filename is None:
    return None

  relpath = os.path.relpath(filename, start=os.path.dirname(tree.__file__))
  # URLs always use '/', even if the docs are built on Windows.
  relpath = relpath.replace(os.sep, '/')
  first_line = lineno
  last_line = lineno + len(source) - 1
  # TODO(slebedev): support tags after we release an initial version.
  # GitHub highlights a line range with the fragment "#L<start>-L<end>".
  return (f'https://github.com/deepmind/tree/blob/master/tree/{relpath}'
          f'#L{first_line}-L{last_line}')
tree-0.1.9/docs/index.rst 0000664 0000000 0000000 00000001571 14746732372 0015274 0 ustar 00root root 0000000 0000000 ##################
Tree Documentation
##################
.. toctree::
:maxdepth: 2
:hidden:
api
changes
recipes
``tree`` is a library for working with nested data structures. In a way,
``tree`` generalizes the builtin :func:`map` function which only supports
flat sequences, and allows to apply a function to each "leaf" preserving
the overall structure.
Here's a quick example::
>>> tree.map_structure(lambda v: v**2, [[1], [[[2, 3]]], [4]])
[[1], [[[4, 9]]], [16]]
.. note::
   ``tree`` was originally part of TensorFlow, where it is available
   as ``tf.nest``.
Installation
============
Install ``tree`` by running::
$ pip install dm-tree
Support
=======
If you are having issues, please let us know by filing an issue on our
`issue tracker <https://github.com/deepmind/tree/issues>`_.
License
=======
Tree is licensed under the Apache 2.0 License.
tree-0.1.9/docs/recipes.rst 0000664 0000000 0000000 00000002057 14746732372 0015617 0 ustar 00root root 0000000 0000000 ############
Recipes
############
Concatenate nested array structures
===================================
>>> tree.map_structure(lambda *args: np.concatenate(args, axis=1),
... {'a': np.ones((2, 1))},
... {'a': np.zeros((2, 1))})
{'a': array([[1., 0.],
[1., 0.]])}
>>> tree.map_structure(lambda *args: np.concatenate(args, axis=0),
... {'a': np.ones((2, 1))},
... {'a': np.zeros((2, 1))})
{'a': array([[1.],
[1.],
[0.],
[0.]])}
Exclude "meta" keys while mapping across structures
===================================================
>>> d1 = {'key_to_exclude': None, 'a': 1}
>>> d2 = {'key_to_exclude': None, 'a': 2}
>>> d3 = {'a': 3}
>>> tree.map_structure_up_to({'a': True}, lambda x, y, z: x+y+z, d1, d2, d3)
{'a': 6}
Broadcast a value across a reference structure
==============================================
>>> reference_tree = {'a': 1, 'b': (2, 3)}
>>> value = np.inf
>>> tree.map_structure(lambda _: value, reference_tree)
{'a': inf, 'b': (inf, inf)}
tree-0.1.9/docs/requirements.txt 0000664 0000000 0000000 00000000046 14746732372 0016713 0 ustar 00root root 0000000 0000000 sphinx>=2.0.1
sphinx_rtd_theme>=0.4.3
tree-0.1.9/readthedocs.yml 0000664 0000000 0000000 00000000426 14746732372 0015511 0 ustar 00root root 0000000 0000000 # Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
version: 2

# Read the Docs requires an explicit build image. The legacy
# `python.version` key (previously 3.7) is no longer accepted, and 3.7 also
# predates the package's python_requires>=3.10.
build:
  os: ubuntu-22.04
  tools:
    python: "3.11"

sphinx:
  builder: html
  configuration: docs/conf.py
  fail_on_warning: false

python:
  install:
    - requirements: docs/requirements.txt
tree-0.1.9/setup.py 0000664 0000000 0000000 00000013233 14746732372 0014213 0 ustar 00root root 0000000 0000000 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Setup for pip package."""
import os
import platform
import shutil
import subprocess
import sys
import sysconfig
import setuptools
from setuptools.command import build_ext
here = os.path.dirname(os.path.abspath(__file__))
def _get_tree_version():
  """Parse the version string from tree/__init__.py.

  Only the first line that starts with ``__version__`` is executed, so the
  package itself never has to be imported to learn its version.

  Returns:
    The value assigned to ``__version__`` on that line.

  Raises:
    ValueError: if no line of tree/__init__.py starts with ``__version__``.
  """
  init_path = os.path.join(here, 'tree', '__init__.py')
  # Decode explicitly: open()'s default encoding is locale-dependent.
  with open(init_path, encoding='utf-8') as f:
    version_line = next(
        (line for line in f if line.startswith('__version__')), None)
  if version_line is None:
    raise ValueError('__version__ not defined in tree/__init__.py')
  ns = {}
  exec(version_line, ns)  # pylint: disable=exec-used
  return ns['__version__']
class CMakeExtension(setuptools.Extension):
  """A setuptools extension whose compilation is delegated to CMake.

  setuptools receives an empty ``sources`` list, so it never tries to compile
  anything itself. The CMake-driven build command configures the project
  found in ``source_dir`` instead.

  Attributes:
    source_dir: absolute path of the CMake source directory for this
      extension.
  """

  def __init__(self, name, source_dir=''):
    super().__init__(name=name, sources=[])
    # Resolve against the current working directory right away, so later
    # directory changes do not affect which sources get built.
    absolute_source_dir = os.path.abspath(source_dir)
    self.source_dir = absolute_source_dir
class BuildCMakeExtension(build_ext.build_ext):
  """Our custom build_ext command.

  Uses CMake to build extensions instead of a bare compiler (e.g. gcc, clang).
  """

  def run(self):
    """Checks that CMake is available, then builds every extension."""
    self._check_build_environment()
    for ext in self.extensions:
      self.build_extension(ext)

  def _check_build_environment(self):
    """Checks that the ``cmake`` executable can be launched.

    Raises:
      RuntimeError: if ``cmake`` cannot be executed (e.g. it is not on PATH).
    """
    try:
      subprocess.check_call(['cmake', '--version'])
    except OSError as e:
      # Use a distinct loop name so the caught exception stays readable.
      ext_names = ', '.join(ext.name for ext in self.extensions)
      raise RuntimeError(
          'CMake must be installed to build the following extensions: '
          f'{ext_names}'
      ) from e
    print('Found CMake')

  def build_extension(self, ext):
    """Configures and builds ``ext`` with CMake.

    Args:
      ext: the extension to build; its ``source_dir`` is passed to
        ``cmake -S``.
    """
    extension_dir = os.path.abspath(
        os.path.dirname(self.get_ext_fullpath(ext.name)))
    build_cfg = 'Debug' if self.debug else 'Release'

    os.makedirs(self.build_temp, exist_ok=True)
    subprocess.check_call(
        ['cmake', '-S', ext.source_dir, '-B', self.build_temp]
        + self._cmake_configure_args(extension_dir, build_cfg))

    num_jobs = (f'-j{self.parallel}',) if self.parallel else ()
    subprocess.check_call([
        'cmake', '--build', self.build_temp, *num_jobs, '--config', build_cfg
    ])

    self._flatten_config_subdirs(extension_dir)

  @staticmethod
  def _cmake_configure_args(extension_dir, build_cfg):
    """Returns the ``-D`` cache definitions for the CMake configure step."""
    cmake_args = [
        f'-DPython3_ROOT_DIR={sys.prefix}',
        f'-DPython3_EXECUTABLE={sys.executable}',
        f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extension_dir}',
        f'-DCMAKE_BUILD_TYPE={build_cfg}',
    ]
    if platform.system() != 'Windows':
      paths = sysconfig.get_paths()
      cmake_args.extend([
          f'-DPython3_LIBRARY={paths["stdlib"]}',
          f'-DPython3_INCLUDE_DIR={paths["include"]}',
      ])
    osx_archs = BuildCMakeExtension._osx_architectures()
    # Only pin architectures when some were actually requested; an empty
    # value would just be a no-op cache entry.
    if osx_archs:
      cmake_args.append(f'-DCMAKE_OSX_ARCHITECTURES={";".join(osx_archs)}')
    return cmake_args

  @staticmethod
  def _osx_architectures():
    """Returns the macOS architectures requested through ``ARCHFLAGS``.

    Only ``x86_64`` and ``arm64`` are recognised, in that order. The result
    is empty off macOS, when ``ARCHFLAGS`` is unset or empty, or when it
    names neither architecture.
    """
    archflags = os.environ.get('ARCHFLAGS')
    if platform.system() != 'Darwin' or not archflags:
      return []
    # Parse "-arch <name>" pairs by token, so the check does not depend on
    # exact spacing and "-arch x86_64h" is not mistaken for x86_64.
    tokens = archflags.split()
    requested = {
        tokens[i + 1] for i in range(len(tokens) - 1) if tokens[i] == '-arch'
    }
    return [arch for arch in ('x86_64', 'arm64') if arch in requested]

  @staticmethod
  def _flatten_config_subdirs(extension_dir):
    """Moves outputs out of CMake's per-configuration subdirectories.

    Multi-config generators (such as Visual Studio on Windows) write into
    ``<extension_dir>/Release`` or ``<extension_dir>/Debug``. Anything found
    there is moved into ``<extension_dir>/tree``.
    """
    tree_dir = os.path.join(extension_dir, 'tree')
    for cfg in ('Release', 'Debug'):
      cfg_dir = os.path.join(extension_dir, cfg)
      if not os.path.isdir(cfg_dir):
        continue
      os.makedirs(tree_dir, exist_ok=True)
      for name in os.listdir(cfg_dir):
        # Pass an explicit destination path rather than the directory, so a
        # file left by a previous build is replaced instead of raising
        # "Destination path already exists".
        shutil.move(os.path.join(cfg_dir, name), os.path.join(tree_dir, name))
# Read the long description up front and close the file immediately. UTF-8 is
# explicit so the result does not depend on the platform's default encoding.
with open(os.path.join(here, 'README.md'), encoding='utf-8') as _readme_file:
  _long_description = _readme_file.read()

setuptools.setup(
    name='dm-tree',
    version=_get_tree_version(),
    url='https://github.com/deepmind/tree',
    description='Tree is a library for working with nested data structures.',
    author='DeepMind',
    author_email='tree-copybara@google.com',
    long_description=_long_description,
    long_description_content_type='text/markdown',
    # Contained modules and scripts.
    packages=setuptools.find_packages(),
    python_requires='>=3.10',
    install_requires=[
        'absl-py>=0.6.1',
        'attrs>=18.2.0',
        'numpy>=1.21',
        "numpy>=1.21.2; python_version>='3.10'",
        "numpy>=1.23.3; python_version>='3.11'",
        "numpy>=1.26.0; python_version>='3.12'",
        "numpy>=2.1.0; python_version>='3.13'",
        'wrapt>=1.11.2',
    ],
    test_suite='tree',
    # The `_tree` extension is built by CMake (see BuildCMakeExtension).
    cmdclass=dict(build_ext=BuildCMakeExtension),
    ext_modules=[CMakeExtension('_tree', source_dir='tree')],
    zip_safe=False,
    # PyPI package information.
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Programming Language :: Python :: 3.10',
        'Programming Language :: Python :: 3.11',
        'Programming Language :: Python :: 3.12',
        'Programming Language :: Python :: 3.13',
        'Topic :: Scientific/Engineering :: Mathematics',
        'Topic :: Software Development :: Libraries',
    ],
    license='Apache 2.0',
    keywords='tree nest flatten',
)
tree-0.1.9/tree/ 0000775 0000000 0000000 00000000000 14746732372 0013436 5 ustar 00root root 0000000 0000000 tree-0.1.9/tree/CMakeLists.txt 0000664 0000000 0000000 00000007762 14746732372 0016212 0 ustar 00root root 0000000 0000000 # Version >= 3.24 required for new `FindPython` module and `FIND_PACKAGE_ARGS`
# keyword of `FetchContent` module.
# https://cmake.org/cmake/help/v3.24/release/3.24.html
cmake_minimum_required(VERSION 3.24)
# CMP0135 NEW: files extracted from URL downloads (as used by FetchContent
# below) get the extraction time as their timestamp, not the archive's.
cmake_policy(SET CMP0135 NEW)
project (tree LANGUAGES CXX)
# Opt-in switches: when ON, the matching FetchContent_Declare() below gets
# FIND_PACKAGE_ARGS, so an installed copy is tried first and the pinned
# archive is downloaded only as a fallback.
option(USE_SYSTEM_ABSEIL "Force use of system abseil-cpp" OFF)
option(USE_SYSTEM_PYBIND11 "Force use of system pybind11" OFF)
# Required for Python.h and python binding.
# NOTE(review): not marked REQUIRED, so configuration continues even if some
# requested components are not found; confirm this is intentional before
# adding REQUIRED.
find_package(Python3 COMPONENTS Interpreter Development)
include_directories(SYSTEM ${Python3_INCLUDE_DIRS})
# NOTE(review): setup.py declares python_requires='>=3.10'; this 3.6.0 floor
# is looser than that — confirm whether it should be raised.
if(Python3_VERSION VERSION_LESS "3.6.0")
message(FATAL_ERROR
"Python found ${Python3_VERSION} < 3.6.0")
endif()
# Use C++14 standard. Declared as a cache entry, so a value given on the
# command line (-DCMAKE_CXX_STANDARD=...) takes precedence; the value is also
# placed into ABSEIL_CMAKE_ARGS below.
set(CMAKE_CXX_STANDARD 14 CACHE STRING "C++ version selection")
# Position-independent code is needed for Python extension modules.
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
# Set default build type, using CMake's canonical spelling `Release`.
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Release
      CACHE STRING "Choose the type of build: Debug Release."
      FORCE)
endif()
message("Current build type is: ${CMAKE_BUILD_TYPE}")
message("PROJECT_BINARY_DIR is: ${PROJECT_BINARY_DIR}")
if (NOT (WIN32 OR MSVC))
  # Compare case-insensitively so that e.g. `release`, `Release` and
  # `RELEASE` all select the same flags.
  string(TOLOWER "${CMAKE_BUILD_TYPE}" _tree_build_type)
  if(_tree_build_type STREQUAL "debug")
    # Basic build for debugging.
    # -Og enables optimizations that do not interfere with debugging.
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Og")
  elseif(_tree_build_type STREQUAL "release")
    # Optimized release build: turn off debug runtime checks
    # and turn on highest speed optimizations.
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNDEBUG -O3")
  endif()
endif()
if(APPLE)
  # macOS only. `-undefined dynamic_lookup` is needed for pybind11 linking;
  # `-Wno-everything -w` silences all compiler warnings.
  string(APPEND CMAKE_CXX_FLAGS " -Wno-everything -w -undefined dynamic_lookup")
  # Search frameworks last, so that CMake picks the right Python when the
  # user has a virtual environment active.
  set(CMAKE_FIND_FRAMEWORK LAST)
endif()
# Use `FetchContent` module to manage all external dependencies (i.e.
# abseil-cpp and pybind11).
include(FetchContent)
# Needed to disable Abseil tests.
set(BUILD_TESTING OFF)
# Try to find abseil-cpp package system-wide first.
# Only when USE_SYSTEM_ABSEIL is ON: ABSEIL_FIND_PACKAGE_ARGS then expands to
# the FIND_PACKAGE_ARGS keyword in FetchContent_Declare() below. When OFF the
# variable is unset and expands to nothing, so the archive is always fetched.
if (USE_SYSTEM_ABSEIL)
message(STATUS "Use system abseil-cpp: ${USE_SYSTEM_ABSEIL}")
set(ABSEIL_FIND_PACKAGE_ARGS FIND_PACKAGE_ARGS)
endif()
# Include abseil-cpp.
set(ABSEIL_REPO https://github.com/abseil/abseil-cpp)
# Settings intended for the abseil sub-build (passed as CMAKE_ARGS below).
# NOTE(review): the FetchContent docs state that configure/build/install step
# options are ignored by FetchContent_Declare(), which would make CMAKE_ARGS
# (and so this list) a no-op, with abseil inheriting this project's variables
# instead — confirm before relying on these values.
set(ABSEIL_CMAKE_ARGS
"-DCMAKE_INSTALL_PREFIX=${CMAKE_SOURCE_DIR}/abseil-cpp"
"-DCMAKE_CXX_STANDARD=${CMAKE_CXX_STANDARD}"
"-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}"
"-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
"-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}"
"-DCMAKE_POSITION_INDEPENDENT_CODE=${CMAKE_POSITION_INDEPENDENT_CODE}"
"-DLIBRARY_OUTPUT_PATH=${CMAKE_SOURCE_DIR}/abseil-cpp/lib"
"-DABSL_PROPAGATE_CXX_STD=ON")
# Forward the macOS target architectures (set by setup.py from ARCHFLAGS).
if(DEFINED CMAKE_OSX_ARCHITECTURES)
set(ABSEIL_CMAKE_ARGS
${ABSEIL_CMAKE_ARGS}
"-DCMAKE_OSX_ARCHITECTURES=${CMAKE_OSX_ARCHITECTURES}")
endif()
# Pinned abseil-cpp release archive, verified by SHA256.
# ${ABSEIL_FIND_PACKAGE_ARGS} must stay last: everything after the
# FIND_PACKAGE_ARGS keyword is forwarded to find_package().
# NOTE(review): EXCLUDE_FROM_ALL as a FetchContent_Declare() option is
# documented from CMake 3.28, above the 3.24 minimum — confirm it has the
# intended effect on older CMake versions.
FetchContent_Declare(
absl
URL ${ABSEIL_REPO}/archive/refs/tags/20220623.2.tar.gz
URL_HASH SHA256=773652c0fc276bcd5c461668dc112d0e3b6cde499600bfe3499c5fdda4ed4a5b
CMAKE_ARGS ${ABSEIL_CMAKE_ARGS}
EXCLUDE_FROM_ALL
${ABSEIL_FIND_PACKAGE_ARGS})
# Try to find pybind11 package system-wide first.
# As for abseil: only when USE_SYSTEM_PYBIND11 is ON; otherwise the pinned
# archive is always fetched.
if (USE_SYSTEM_PYBIND11)
message(STATUS "Use system pybind11: ${USE_SYSTEM_PYBIND11}")
set(PYBIND11_FIND_PACKAGE_ARGS FIND_PACKAGE_ARGS)
endif()
# Pinned pybind11 release archive, verified by SHA256 (FIND_PACKAGE_ARGS,
# if any, must stay last).
FetchContent_Declare(
pybind11
URL https://github.com/pybind/pybind11/archive/refs/tags/v2.10.1.tar.gz
URL_HASH SHA256=111014b516b625083bef701df7880f78c2243835abdb263065b6b59b960b6bad
${PYBIND11_FIND_PACKAGE_ARGS})
# Find or fetch both dependencies and add them to the build; this provides
# the absl:: targets and pybind11_add_module() used below.
FetchContent_MakeAvailable(absl pybind11)
# Build the `_tree` extension module from the C++ sources with pybind11.
pybind11_add_module(_tree tree.h tree.cc)

# Abseil components the module links against. PRIVATE: they are needed to
# build `_tree` itself and are not propagated to targets that link `_tree`.
target_link_libraries(_tree PRIVATE
  absl::int128
  absl::raw_hash_set
  absl::raw_logging_internal
  absl::strings
  absl::throw_delegate
)

# Name the module file `tree/_tree...` so it is written into a `tree/`
# subdirectory of the library output directory, making it private to the
# `tree` package.
set_property(TARGET _tree PROPERTY OUTPUT_NAME tree/_tree)
tree-0.1.9/tree/__init__.py 0000664 0000000 0000000 00000101512 14746732372 0015547 0 ustar 00root root 0000000 0000000 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for working with nested data structures."""
from collections import abc as collections_abc
import logging
import sys
from typing import Mapping, Sequence, TypeVar, Union
from .sequence import _is_attrs
from .sequence import _is_namedtuple
from .sequence import _sequence_like
from .sequence import _sorted
# pylint: disable=g-import-not-at-top
try:
  import wrapt
  ObjectProxy = wrapt.ObjectProxy
except ImportError:
  # Without wrapt, an empty placeholder class is used, so
  # `isinstance(x, ObjectProxy)` checks are simply False.
  class ObjectProxy(object):
    """Stub-class for `wrapt.ObjectProxy`."""

# The compiled `_tree` extension is mandatory, except when this module is
# imported while Sphinx is loaded; then `_tree` is set to None instead.
try:
  from tree import _tree
except ImportError:
  if "sphinx" not in sys.modules:
    raise
  _tree = None
# pylint: enable=g-import-not-at-top
# Public API exported by `from tree import *`.
__all__ = [
    "is_nested",
    "assert_same_structure",
    "unflatten_as",
    "flatten",
    "flatten_up_to",
    "flatten_with_path",
    "flatten_with_path_up_to",
    "map_structure",
    "map_structure_up_to",
    "map_structure_with_path",
    "map_structure_with_path_up_to",
    "traverse",
    "MAP_TO_NONE",
]

# setup.py executes the first line of this file that starts with
# `__version__`, so keep this a single-line assignment.
__version__ = "0.1.9"

# Note: this is *not* the same as `six.string_types`, which in Python3 is just
# `(str,)` (i.e. it does not include byte strings).
# Values of these types are treated as leaves, not sequences, when walking
# a shallow structure.
_TEXT_OR_BYTES = (str, bytes)

# Error-message templates used by the shallow-structure checks below.
_SHALLOW_TREE_HAS_INVALID_KEYS = (
    "The shallow_tree's keys are not a subset of the input_tree's keys. The "
    "shallow_tree has the following keys that are not in the input_tree: {}.")

_STRUCTURES_HAVE_MISMATCHING_TYPES = (
    "The two structures don't have the same sequence type. Input structure has "
    "type {input_type}, while shallow structure has type {shallow_type}.")

_STRUCTURES_HAVE_MISMATCHING_LENGTHS = (
    "The two structures don't have the same sequence length. Input "
    "structure has length {input_length}, while shallow structure has length "
    "{shallow_length}."
)

_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = (
    "The input_tree has fewer elements than the shallow_tree. Input structure "
    "has length {input_size}, while shallow structure has length "
    "{shallow_size}.")

_IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ = (
    "If shallow structure is a sequence, input must also be a sequence. "
    "Input has type: {}.")

_IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH = (
    "If shallow structure is a sequence, input must also be a sequence. "
    "Input at path: {path} has type: {input_type}.")

# Type variables for the key and value types of `StructureKV`.
K = TypeVar("K")
V = TypeVar("V")

# A generic monomorphic structure type, e.g. ``StructureKV[str, int]``
# is an arbitrarily nested structure where keys must be of type ``str``
# and values are integers.
StructureKV = Union[
    Sequence["StructureKV[K, V]"],
    Mapping[K, "StructureKV[K, V]"],
    V,
]

# `StructureKV` specialised to string keys; generic in the value type `V`.
Structure = StructureKV[str, V]
def _get_attrs_items(obj):
"""Returns a list of (name, value) pairs from an attrs instance.
The list will be sorted by name.
Args:
obj: an object.
Returns:
A list of (attr_name, attr_value) pairs.
"""
return [(attr.name, getattr(obj, attr.name))
for attr in obj.__class__.__attrs_attrs__]
def _yield_value(iterable):
  """Yields just the values of `_yield_sorted_items(iterable)`, in order."""
  yield from (value for _, value in _yield_sorted_items(iterable))
def _yield_sorted_items(iterable):
  """Yields (key, value) pairs for `iterable` in a deterministic order.

  What the keys are, and their order, depends on the kind of `iterable`:

  * Mappings: the mapping's keys, visited in the order given by
    ``_sorted(iterable)`` (i.e. sorted), regardless of insertion order.
  * attrs instances: attribute names, in ``__attrs_attrs__`` order.
  * namedtuples: field names, in ``_fields`` order.
  * Anything else: integer positions from ``enumerate``, in iteration order.

  Args:
    iterable: an iterable.

  Yields:
    The iterable's (key, value) pairs, in the order described above.
  """
  if isinstance(iterable, collections_abc.Mapping):
    # Iterate through dictionaries in a deterministic order by sorting the
    # keys. Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    for key in _sorted(iterable):
      yield key, iterable[key]
  elif _is_attrs(iterable):
    yield from _get_attrs_items(iterable)
  elif _is_namedtuple(iterable):
    for field in iterable._fields:
      yield (field, getattr(iterable, field))
  else:
    yield from enumerate(iterable)
def _num_elements(structure):
  """Returns how many immediate children `structure` has.

  For attrs instances this is the number of declared attributes; for anything
  else it is ``len(structure)``.
  """
  if not _is_attrs(structure):
    return len(structure)
  return len(structure.__class__.__attrs_attrs__)
def is_nested(structure):
  """Checks if a given structure is nested.

  >>> tree.is_nested(42)
  False
  >>> tree.is_nested({"foo": 42})
  True

  The decision is made by the compiled ``_tree.is_sequence``.

  Args:
    structure: A structure to check.

  Returns:
    `True` if a given structure is nested, i.e. is a sequence, a mapping,
    or a namedtuple, and `False` otherwise.
  """
  nested = _tree.is_sequence(structure)
  return nested
def flatten(structure):
  r"""Flattens a possibly nested structure into a list.

  >>> tree.flatten([[1, 2, 3], [4, [5], [[6]]]])
  [1, 2, 3, 4, 5, 6]

  If `structure` is not nested, the result is a single-element list.

  >>> tree.flatten(None)
  [None]
  >>> tree.flatten(1)
  [1]

  In the case of dict instances, the sequence consists of the values,
  sorted by key to ensure deterministic behavior. This is true also for
  :class:`~collections.OrderedDict` instances: their sequence order is
  ignored, the sorting order of keys is used instead. The same convention
  is followed in :func:`~tree.unflatten`. This correctly unflattens dicts
  and ``OrderedDict``\ s after they have been flattened, and also allows
  flattening an ``OrderedDict`` and then unflattening it back using a
  corresponding plain dict, or vice-versa.

  Dictionaries with non-sortable keys cannot be flattened.

  >>> tree.flatten({100: 'world!', 6: 'Hello'})
  ['Hello', 'world!']

  Args:
    structure: An arbitrarily nested structure.

  Returns:
    A list, the flattened version of the input `structure`, as produced by
    the compiled ``_tree.flatten``.

  Raises:
    TypeError: If `structure` is or contains a mapping with non-sortable keys.
  """
  leaves = _tree.flatten(structure)
  return leaves
class _DotString(object):
def __str__(self):
return "."
def __repr__(self):
return "."
_DOT = _DotString()
def assert_same_structure(a, b, check_types=True):
  """Asserts that two structures are nested in the same way.

  >>> tree.assert_same_structure([(0, 1)], [(2, 3)])

  Note that namedtuples with identical name and fields are always considered
  to have the same shallow structure (even with `check_types=True`).

  >>> Foo = collections.namedtuple('Foo', ['a', 'b'])
  >>> AlsoFoo = collections.namedtuple('Foo', ['a', 'b'])
  >>> tree.assert_same_structure(Foo(0, 1), AlsoFoo(2, 3))

  Named tuples with different names are considered to have different shallow
  structures:

  >>> Bar = collections.namedtuple('Bar', ['a', 'b'])
  >>> tree.assert_same_structure(Foo(0, 1), Bar(2, 3))
  Traceback (most recent call last):
    ...
  TypeError: The two structures don't have the same nested structure.
  ...

  On failure, the error from the compiled check is re-raised with the same
  exception type, followed by both structures rendered with every leaf
  replaced by a dot.

  Args:
    a: an arbitrarily nested structure.
    b: an arbitrarily nested structure.
    check_types: if `True` (default) types of sequences are checked as
      well, including the keys of dictionaries. If set to `False`, for example
      a list and a tuple of objects will look the same if they have the same
      size. Note that namedtuples with identical name and fields are always
      considered to have the same shallow structure.

  Raises:
    ValueError: If the two structures do not have the same number of elements or
      if the two structures are not nested in the same way.
    TypeError: If the two structures differ in the type of sequence in any of
      their substructures. Only possible if `check_types` is `True`.
  """
  try:
    _tree.assert_same_structure(a, b, check_types)
  except (ValueError, TypeError) as e:
    def _skeleton(structure):
      return str(map_structure(lambda _: _DOT, structure))

    first_skeleton = _skeleton(a)
    second_skeleton = _skeleton(b)
    message = ("%s\n"
               "Entire first structure:\n%s\n"
               "Entire second structure:\n%s"
               % (e, first_skeleton, second_skeleton))
    raise type(e)(message)
def _packed_nest_with_indices(structure, flat, index):
  """Helper function for ``unflatten_as``.

  Walks the immediate children of `structure` (in `_yield_value` order),
  replacing each leaf with the next element of `flat` and recursing into
  nested children.

  Args:
    structure: Substructure (list / tuple / dict) to mimic.
    flat: Flattened values to output substructure for.
    index: Index at which to start reading from flat.

  Returns:
    The tuple ``(new_index, packed)``, where:
      * new_index - the updated index into `flat` having processed `structure`.
      * packed - a plain list with one entry per child of `structure`, filled
        from `flat` starting at `index`. Nested children are already rebuilt
        in their own types via `_sequence_like`; converting `packed` itself
        into `structure`'s type is left to the caller.

  Raises:
    IndexError: if `structure` has more leaves than `flat` has elements from
      `index` onwards (when `flat` is a list or similar sequence).
  """
  packed = []
  for s in _yield_value(structure):
    if is_nested(s):
      new_index, child = _packed_nest_with_indices(s, flat, index)
      packed.append(_sequence_like(s, child))
      index = new_index
    else:
      packed.append(flat[index])
      index += 1
  return index, packed
def unflatten_as(structure, flat_sequence):
  r"""Unflattens a sequence into a given structure.

  >>> tree.unflatten_as([[1, 2], [[3], [4]]], [5, 6, 7, 8])
  [[5, 6], [[7], [8]]]

  If `structure` is a scalar, `flat_sequence` must be a single-element list;
  in this case the return value is ``flat_sequence[0]``.

  >>> tree.unflatten_as(None, [1])
  1

  If `structure` is or contains a dict instance, the keys will be sorted to
  pack the flat sequence in deterministic order. This is true also for
  :class:`~collections.OrderedDict` instances: their sequence order is
  ignored, the sorting order of keys is used instead. The same convention
  is followed in :func:`~tree.flatten`. This correctly unflattens dicts
  and ``OrderedDict``\ s after they have been flattened, and also allows
  flattening an ``OrderedDict`` and then unflattening it back using a
  corresponding plain dict, or vice-versa.

  Dictionaries with non-sortable keys cannot be unflattened.

  >>> tree.unflatten_as({1: None, 2: None}, ['Hello', 'world!'])
  {1: 'Hello', 2: 'world!'}

  Args:
    structure: Arbitrarily nested structure.
    flat_sequence: Sequence to unflatten.

  Returns:
    `flat_sequence` unflattened into `structure`.

  Raises:
    TypeError: If `flat_sequence` is not itself nested (e.g. a scalar).
    ValueError: If `flat_sequence` and `structure` have different
      element counts.
    TypeError: If `structure` is or contains a mapping with non-sortable keys.
  """
  if not is_nested(flat_sequence):
    raise TypeError("flat_sequence must be a sequence not a {}:\n{}".format(
        type(flat_sequence), flat_sequence))

  if not is_nested(structure):
    if len(flat_sequence) != 1:
      # "!= 1" rather than "> 1": an empty flat_sequence also ends up here.
      raise ValueError(
          "Structure is a scalar but len(flat_sequence) == %d != 1"
          % len(flat_sequence))
    return flat_sequence[0]

  flat_structure = flatten(structure)
  if len(flat_structure) != len(flat_sequence):
    raise ValueError(
        "Could not pack sequence. Structure had %d elements, but flat_sequence "
        "had %d elements. Structure: %s, flat_sequence: %s."
        % (len(flat_structure), len(flat_sequence), structure, flat_sequence))

  _, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
  return _sequence_like(structure, packed)
def map_structure(func, *structures, **kwargs):  # pylint: disable=redefined-builtin
  """Maps `func` through given structures.

  >>> structure = [[1], [2], [3]]
  >>> tree.map_structure(lambda v: v**2, structure)
  [[1], [4], [9]]
  >>> tree.map_structure(lambda x, y: x * y, structure, structure)
  [[1], [4], [9]]
  >>> Foo = collections.namedtuple('Foo', ['a', 'b'])
  >>> structure = Foo(a=1, b=2)
  >>> tree.map_structure(lambda v: v * 2, structure)
  Foo(a=2, b=4)

  Args:
    func: A callable that accepts as many arguments as there are structures.
    *structures: Arbitrarily nested structures of the same layout.
    **kwargs: The only valid keyword argument is `check_types`. If `True`
      (default) the types of components within the structures have
      to match, e.g. ``tree.map_structure(func, [1], (1,))`` will raise
      a `TypeError`, otherwise this is not enforced. Note that namedtuples
      with identical name and fields are considered to be the same type.

  Returns:
    A new structure with the same layout as the given ones. If the
    `structures` have components of varying types, the resulting structure
    will use the same types as ``structures[0]``.

  Raises:
    TypeError: If `func` is not callable.
    ValueError: If the two structures do not have the same number of elements or
      if the two structures are not nested in the same way.
    TypeError: If `check_types` is `True` and any two `structures`
      differ in the types of their components.
    ValueError: If no structures were given or if a keyword argument other
      than `check_types` is provided.
  """
  if not callable(func):
    # Wrap `func` in a 1-tuple: `"%s" % func` would itself fail (with an
    # unrelated "not all arguments converted" error) for tuple values.
    raise TypeError("func must be callable, got: %s" % (func,))

  if not structures:
    raise ValueError("Must provide at least one structure")

  check_types = kwargs.pop("check_types", True)
  if kwargs:
    raise ValueError(
        "Only valid keyword arguments are `check_types` "
        "not: `%s`" % ("`, `".join(kwargs.keys())))

  for other in structures[1:]:
    assert_same_structure(structures[0], other, check_types=check_types)
  return unflatten_as(structures[0],
                      [func(*args) for args in zip(*map(flatten, structures))])
def map_structure_with_path(func, *structures, **kwargs):
  """Maps `func` through given structures.

  This is a variant of :func:`~tree.map_structure` which accumulates
  a *path* while mapping through the structures. A path is a tuple of
  indices and/or keys which uniquely identifies the positions of the
  arguments passed to `func`.

  >>> tree.map_structure_with_path(
  ...     lambda path, v: (path, v**2),
  ...     [{"foo": 42}])
  [{'foo': ((0, 'foo'), 1764)}]

  Args:
    func: A callable that accepts a path and as many arguments as there are
      structures.
    *structures: Arbitrarily nested structures of the same layout.
    **kwargs: Forwarded to :func:`~tree.map_structure_with_path_up_to`, where
      `check_types` is accepted for backwards compatibility but only logs a
      deprecation warning: component types are not checked.

  Returns:
    A new structure with the same layout as the given ones. If the
    `structures` have components of varying types, the resulting structure
    will use the same types as ``structures[0]``.

  Raises:
    ValueError: If no structures were given, or if a key/index present in
      ``structures[0]`` is missing from another structure.
    TypeError: If `func` is not callable (raised when it is first called).
  """
  if not structures:
    raise ValueError("Must provide at least one structure")
  return map_structure_with_path_up_to(structures[0], func, *structures,
                                       **kwargs)
def _yield_flat_up_to(shallow_tree, input_tree, path=()):
  """Yields (path, value) pairs of input_tree flattened up to shallow_tree.

  A node of `shallow_tree` is a leaf when it is a string/bytes value, or when
  it is neither a Mapping, a Sequence, a namedtuple nor an attrs instance.

  Args:
    shallow_tree: Nested structure. Traverse no further than its leaf nodes.
    input_tree: Nested structure. Return the paths and values from this tree.
      Must have the same upper structure as shallow_tree.
    path: Tuple. Optional argument, only used when recursing. The path from the
      root of the original shallow_tree, down to the root of the shallow_tree
      arg of this recursive call.

  Yields:
    Pairs of (path, value), where path is the tuple path of a leaf node in
    shallow_tree, and value is the value of the corresponding node in
    input_tree.
  """
  is_leaf = (
      isinstance(shallow_tree, _TEXT_OR_BYTES)
      or not (isinstance(shallow_tree, (collections_abc.Mapping,
                                         collections_abc.Sequence))
              or _is_namedtuple(shallow_tree)
              or _is_attrs(shallow_tree)))
  if is_leaf:
    yield (path, input_tree)
    return

  # Index the input's children by key so that each child of the shallow tree
  # can be looked up directly; a missing key raises KeyError.
  input_children = dict(_yield_sorted_items(input_tree))
  for key, shallow_child in _yield_sorted_items(shallow_tree):
    child_path = path + (key,)
    yield from _yield_flat_up_to(shallow_child, input_children[key],
                                 path=child_path)
def _multiyield_flat_up_to(shallow_tree, *input_trees):
  """Same as `_yield_flat_up_to`, but takes multiple input trees.

  Yields:
    Tuples ``(path, value_1, ..., value_n)``: the path of a leaf of
    `shallow_tree` followed by the corresponding value from each input tree.

  Raises:
    ValueError: If a key of `shallow_tree` is missing from one of the input
      trees (the underlying lookup raises `KeyError`).
  """
  zipped_iterators = zip(*[_yield_flat_up_to(shallow_tree, input_tree)
                           for input_tree in input_trees])
  # Most recently yielded paths; reported if a later lookup fails. Starts as
  # ((),) so that a failure on the very first leaf reports the empty path.
  paths = ((),)
  try:
    for paths_and_values in zipped_iterators:
      paths, values = zip(*paths_and_values)
      yield paths[:1] + values
  except KeyError as e:
    raise ValueError(f"Could not find key '{e.args[0]}' in some `input_trees`. "
                     "Please ensure the structure of all `input_trees` are "
                     "compatible with `shallow_tree`. The last valid path "
                     f"yielded was {paths[0]}.") from e
def _assert_shallow_structure(shallow_tree,
                              input_tree,
                              path=None,
                              check_types=True):
  """Asserts that `shallow_tree` is a shallow structure of `input_tree`.

  That is, this function recursively tests if each key in shallow_tree has its
  corresponding key in input_tree.

  Examples:

  The following code will raise an exception:

  >>> shallow_tree = {"a": "A", "b": "B"}
  >>> input_tree = {"a": 1, "c": 2}
  >>> _assert_shallow_structure(shallow_tree, input_tree)
  Traceback (most recent call last):
    ...
  ValueError: The shallow_tree's keys are not a subset of the input_tree's ...

  The following code will raise an exception:

  >>> shallow_tree = ["a", "b"]
  >>> input_tree = ["c", ["d", "e"], "f"]
  >>> _assert_shallow_structure(shallow_tree, input_tree)
  Traceback (most recent call last):
    ...
  ValueError: The two structures don't have the same sequence length. ...

  By setting check_types=False, we drop the requirement that corresponding
  nodes in shallow_tree and input_tree have to be the same type. Sequences
  are treated equivalently to Mappables that map integer keys (indices) to
  values. The following code will therefore not raise an exception:

  >>> _assert_shallow_structure({0: "foo"}, ["foo"], check_types=False)

  Args:
    shallow_tree: an arbitrarily nested structure.
    input_tree: an arbitrarily nested structure.
    path: if not `None`, a tuple containing the current path in the nested
      structure. This is only used for more informative error messages.
    check_types: if `True` (default) the sequence types of `shallow_tree` and
      `input_tree` have to be the same.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`. Only raised if `check_types` is `True`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
    ValueError: If a key of `shallow_tree` has no counterpart in `input_tree`.
  """
  if is_nested(shallow_tree):
    if not is_nested(input_tree):
      if path is not None:
        raise TypeError(
            _IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format(
                path=list(path), input_type=type(input_tree)))
      else:
        raise TypeError(
            _IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(
                type(input_tree)))

    # Compare against the wrapped object's type when `shallow_tree` is a
    # proxy object.
    if isinstance(shallow_tree, ObjectProxy):
      shallow_type = type(shallow_tree.__wrapped__)
    else:
      shallow_type = type(shallow_tree)

    if check_types and not isinstance(input_tree, shallow_type):
      # Duck-typing means that nest should be fine with two different
      # namedtuples with identical name and fields.
      shallow_is_namedtuple = _is_namedtuple(shallow_tree, False)
      input_is_namedtuple = _is_namedtuple(input_tree, False)
      if shallow_is_namedtuple and input_is_namedtuple:
        # pylint: disable=protected-access
        if not _tree.same_namedtuples(shallow_tree, input_tree):
          raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
              input_type=type(input_tree),
              shallow_type=shallow_type))
        # pylint: enable=protected-access
      elif not (isinstance(shallow_tree, collections_abc.Mapping)
                and isinstance(input_tree, collections_abc.Mapping)):
        # Two mappings of different types are accepted; anything else is a
        # type mismatch.
        raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
            input_type=type(input_tree),
            shallow_type=shallow_type))

    # Lengths must match exactly; this also rules out an input that is
    # smaller than the shallow tree, so no separate check is needed for that.
    input_length = _num_elements(input_tree)
    shallow_length = _num_elements(shallow_tree)
    if input_length != shallow_length:
      raise ValueError(
          _STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
              input_length=input_length,
              shallow_length=shallow_length))

    shallow_iter = _yield_sorted_items(shallow_tree)
    input_iter = _yield_sorted_items(input_tree)

    def get_matching_input_branch(shallow_key):
      # Both iterators visit keys in the same deterministic order, so the
      # shared `input_iter` is advanced until the key matches; keys skipped
      # on the way are never revisited.
      for input_key, input_branch in input_iter:
        if input_key == shallow_key:
          return input_branch
      raise ValueError(_SHALLOW_TREE_HAS_INVALID_KEYS.format([shallow_key]))

    for shallow_key, shallow_branch in shallow_iter:
      input_branch = get_matching_input_branch(shallow_key)
      _assert_shallow_structure(
          shallow_branch,
          input_branch,
          path + (shallow_key,) if path is not None else None,
          check_types=check_types)
def flatten_up_to(shallow_structure, input_structure, check_types=True):
  """Flattens `input_structure` up to `shallow_structure`.

  All further nested components in `input_structure` are retained as-is.

  >>> structure = [[1, 1], [2, 2]]
  >>> tree.flatten_up_to([None, None], structure)
  [[1, 1], [2, 2]]
  >>> tree.flatten_up_to([None, [None, None]], structure)
  [[1, 1], 2, 2]

  If `shallow_structure` and `input_structure` are not nested, the
  result is a single-element list:

  >>> tree.flatten_up_to(42, 1)
  [1]
  >>> tree.flatten_up_to(42, [1, 2, 3])
  [[1, 2, 3]]

  Args:
    shallow_structure: A structure with the same (but possibly more shallow)
      layout as `input_structure`.
    input_structure: An arbitrarily nested structure.
    check_types: If `True`, check that each node in shallow_tree has the
      same type as the corresponding node in `input_structure`.

  Returns:
    A list, the partially flattened version of `input_structure` wrt
    `shallow_structure`.

  Raises:
    TypeError: If the layout of `shallow_structure` does not match that of
      `input_structure`.
    TypeError: If `check_types` is `True` and `shallow_structure` and
      `input_structure` differ in the types of their components.
    ValueError: If corresponding nodes differ in length or keys.
  """
  _assert_shallow_structure(
      shallow_structure, input_structure, path=None, check_types=check_types)
  # Only the values are wanted; the paths are dropped.
  path_value_pairs = _yield_flat_up_to(shallow_structure, input_structure)
  return [value for _, value in path_value_pairs]
def flatten_with_path_up_to(shallow_structure,
                            input_structure,
                            check_types=True):
  """Flattens `input_structure` up to `shallow_structure`.

  This is a combination of :func:`~tree.flatten_up_to` and
  :func:`~tree.flatten_with_path`: the result keeps the path of every item.

  Args:
    shallow_structure: A structure with the same (but possibly more shallow)
      layout as `input_structure`.
    input_structure: An arbitrarily nested structure.
    check_types: If `True`, check that each node in shallow_tree has the
      same type as the corresponding node in `input_structure`.

  Returns:
    A list of ``(path, item)`` pairs corresponding to the partially flattened
    version of `input_structure` wrt `shallow_structure`.

  Raises:
    TypeError: If the layout of `shallow_structure` does not match that of
      `input_structure`.
    TypeError: If `input_structure` is or contains a mapping with non-sortable
      keys.
    TypeError: If `check_types` is `True` and `shallow_structure` and
      `input_structure` differ in the types of their components.
  """
  # path=() (rather than None) makes structure errors report the offending path.
  _assert_shallow_structure(
      shallow_structure, input_structure, path=(), check_types=check_types)
  return [*_yield_flat_up_to(shallow_structure, input_structure)]
def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
  """Maps `func` through given structures up to `shallow_structure`.

  This is a variant of :func:`~tree.map_structure` which only maps
  the given structures up to `shallow_structure`. All further nested
  components are retained as-is.

  >>> structure = [[1, 1], [2, 2]]
  >>> tree.map_structure_up_to([None, None], len, structure)
  [2, 2]
  >>> tree.map_structure_up_to([None, [None, None]], str, structure)
  ['[1, 1]', ['2', '2']]

  Implemented via :func:`~tree.map_structure_with_path_up_to` with the path
  argument dropped before calling `func`.

  Args:
    shallow_structure: A structure with layout common to all `structures`.
    func: A callable that accepts as many arguments as there are structures.
    *structures: Arbitrarily nested structures of the same layout.
    **kwargs: No valid keyword arguments.

  Raises:
    ValueError: If a key or index of `shallow_structure` is missing from one
      of `structures`, or if no structures were given (for a
      `shallow_structure` with at least one leaf).
    TypeError: If `func` is not callable (raised when it is first called).

  Returns:
    A new structure with the same layout as `shallow_structure`.
  """
  def _call_without_path(unused_path, *values):
    return func(*values)

  return map_structure_with_path_up_to(
      shallow_structure, _call_without_path, *structures, **kwargs)
def map_structure_with_path_up_to(shallow_structure, func, *structures,
                                  **kwargs):
  """Maps `func` through given structures up to `shallow_structure`.

  This is a combination of :func:`~tree.map_structure_up_to` and
  :func:`~tree.map_structure_with_path`.

  Args:
    shallow_structure: A structure with layout common to all `structures`.
    func: A callable that accepts a path and as many arguments as there are
      structures.
    *structures: Arbitrarily nested structures of the same layout.
    **kwargs: No valid keyword arguments. `check_types` is accepted for
      backwards compatibility but only logs a deprecation warning; any other
      keyword is silently ignored.

  Raises:
    ValueError: If a key or index of `shallow_structure` is missing from one
      of `structures`, or if no structures were given (for a
      `shallow_structure` with at least one leaf).
    TypeError: If `func` is not callable (raised when it is first called).

  Returns:
    Result of repeatedly applying `func`. Has the same structure layout
    as `shallow_tree`.
  """
  uses_check_types = "check_types" in kwargs
  del kwargs  # Nothing else is read from the keyword arguments.
  if uses_check_types:
    logging.warning("The use of `check_types` is deprecated and does not have "
                    "any effect.")

  leaf_items = _multiyield_flat_up_to(shallow_structure, *structures)
  mapped = [func(*path_and_values) for path_and_values in leaf_items]
  return unflatten_as(shallow_structure, mapped)
def flatten_with_path(structure):
  r"""Flattens a possibly nested structure into a list.

  This is a variant of :func:`~tree.flatten` which produces a list of
  pairs: ``(path, item)``. A path is a tuple of indices and/or keys
  which uniquely identifies the position of the corresponding ``item``.

  >>> tree.flatten_with_path([{"foo": 42}])
  [((0, 'foo'), 42)]

  Args:
    structure: An arbitrarily nested structure.

  Returns:
    A list of ``(path, item)`` pairs corresponding to the flattened version
    of the input `structure`.

  Raises:
    TypeError:
      If ``structure`` is or contains a mapping with non-sortable keys.
  """
  # Flattening `structure` up to itself reaches every leaf and pairs it with
  # its path.
  return list(_yield_flat_up_to(structure, structure))
#: Special value for use with :func:`traverse`.
#:
#: Returning this sentinel from the traversal function replaces the visited
#: sub-tree with ``None`` in the result. Returning ``None`` itself cannot be
#: used for that, because ``None`` means "keep traversing" (top-down) or
#: "keep the sub-tree unaltered" (bottom-up).
MAP_TO_NONE = object()
def traverse(fn, structure, top_down=True):
  """Traverses the given nested structure, applying the given function.

  The traversal is depth-first. If ``top_down`` is True (default), parents
  are visited before their children (giving the option to avoid traversing
  into a sub-tree).

  >>> visited = []
  >>> tree.traverse(visited.append, [(1, 2), [3], {"a": 4}], top_down=True)
  [(1, 2), [3], {'a': 4}]
  >>> visited
  [[(1, 2), [3], {'a': 4}], (1, 2), 1, 2, [3], 3, {'a': 4}, 4]

  >>> visited = []
  >>> tree.traverse(visited.append, [(1, 2), [3], {"a": 4}], top_down=False)
  [(1, 2), [3], {'a': 4}]
  >>> visited
  [1, 2, (1, 2), 3, [3], 4, {'a': 4}, [(1, 2), [3], {'a': 4}]]

  This is :func:`traverse_with_path` with the path argument dropped.

  Args:
    fn: The function to be applied to each sub-nest of the structure.

      When traversing top-down:
        If ``fn(subtree) is None`` the traversal continues into the sub-tree.
        If ``fn(subtree) is not None`` the traversal does not continue into
        the sub-tree. The sub-tree will be replaced by ``fn(subtree)`` in the
        returned structure (to replace the sub-tree with None, use the special
        value :data:`MAP_TO_NONE`).

      When traversing bottom-up:
        If ``fn(subtree) is None`` the traversed sub-tree is returned unaltered.
        If ``fn(subtree) is not None`` the sub-tree will be replaced by
        ``fn(subtree)`` in the returned structure (to replace the sub-tree
        with None, use the special value :data:`MAP_TO_NONE`).

    structure: The structure to traverse.
    top_down: If True, parent structures will be visited before their children.

  Returns:
    The structured output from the traversal.
  """

  def path_free_fn(unused_path, subtree):
    # Adapts the one-argument `fn` to the `(path, subtree)` convention.
    return fn(subtree)

  return traverse_with_path(path_free_fn, structure, top_down=top_down)
def traverse_with_path(fn, structure, top_down=True):
  """Traverses the given nested structure, applying the given function.

  The traversal is depth-first. If ``top_down`` is True (default), parents
  are visited before their children (giving the option to avoid traversing
  into a sub-tree).

  >>> visited = []
  >>> tree.traverse_with_path(
  ...     lambda path, subtree: visited.append((path, subtree)),
  ...     [(1, 2), [3], {"a": 4}],
  ...     top_down=True)
  [(1, 2), [3], {'a': 4}]
  >>> visited == [
  ...     ((), [(1, 2), [3], {'a': 4}]),
  ...     ((0,), (1, 2)),
  ...     ((0, 0), 1),
  ...     ((0, 1), 2),
  ...     ((1,), [3]),
  ...     ((1, 0), 3),
  ...     ((2,), {'a': 4}),
  ...     ((2, 'a'), 4)]
  True

  >>> visited = []
  >>> tree.traverse_with_path(
  ...     lambda path, subtree: visited.append((path, subtree)),
  ...     [(1, 2), [3], {"a": 4}],
  ...     top_down=False)
  [(1, 2), [3], {'a': 4}]
  >>> visited == [
  ...     ((0, 0), 1),
  ...     ((0, 1), 2),
  ...     ((0,), (1, 2)),
  ...     ((1, 0), 3),
  ...     ((1,), [3]),
  ...     ((2, 'a'), 4),
  ...     ((2,), {'a': 4}),
  ...     ((), [(1, 2), [3], {'a': 4}])]
  True

  Args:
    fn: Called as ``fn(path, subtree)`` for every sub-nest of the structure,
      where ``path`` is the tuple of keys/indices leading to ``subtree``.

      When traversing top-down: if ``fn(path, subtree) is None`` the
      traversal continues into the sub-tree. Otherwise the traversal does not
      descend, and the sub-tree is replaced by ``fn(path, subtree)`` in the
      returned structure (to replace the sub-tree with None, return the
      special value :data:`MAP_TO_NONE`).

      When traversing bottom-up: ``fn`` receives the already-traversed
      sub-tree. If it returns None the sub-tree is kept unaltered; otherwise
      the sub-tree is replaced by the returned value (again, use
      :data:`MAP_TO_NONE` to replace it with None).
    structure: The structure to traverse.
    top_down: If True, parent structures will be visited before their children.

  Returns:
    The structured output from the traversal.
  """

  def visit(path, node):
    """Applies `fn` to `node` (located at `path`) and to its descendants."""

    def visit_child(key_and_child):
      key, child = key_and_child
      return visit(path + (key,), child)

    def rebuild():
      # Leaves come back unchanged; containers are rebuilt from their
      # recursively visited children, in the order `_yield_sorted_items`
      # produces them.
      if not is_nested(node):
        return node
      return _sequence_like(node, map(visit_child, _yield_sorted_items(node)))

    if top_down:
      # The parent is seen first; any non-None outcome prunes the sub-tree.
      outcome = fn(path, node)
      if outcome is None:
        return rebuild()
    else:
      # Children are processed first; `fn` sees the rebuilt node.
      rebuilt = rebuild()
      outcome = fn(path, rebuilt)
      if outcome is None:
        return rebuilt
    return None if outcome is MAP_TO_NONE else outcome

  return visit((), structure)
tree-0.1.9/tree/sequence.py 0000664 0000000 0000000 00000010032 14746732372 0015614 0 ustar 00root root 0000000 0000000 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains _sequence_like and helpers for sequence data structures."""
import collections
from collections import abc as collections_abc
import types
from tree import _tree
# pylint: disable=g-import-not-at-top
try:
  import wrapt
  ObjectProxy = wrapt.ObjectProxy
except ImportError:
  # `wrapt` is optional. Without it, a placeholder class keeps the
  # `isinstance(..., ObjectProxy)` checks in this module valid.
  class ObjectProxy(object):
    """Stub-class for `wrapt.ObjectProxy`."""
def _sorted(dictionary):
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
try:
return sorted(dictionary)
except TypeError:
raise TypeError("tree only supports dicts with sortable keys.")
def _is_attrs(instance):
  """Returns True iff `instance` is an instance of an attrs-decorated class.

  The check is delegated to the C++ extension (`_tree.is_attrs`).
  """
  is_attrs_instance = _tree.is_attrs(instance)
  return is_attrs_instance
def _is_namedtuple(instance, strict=False):
  """Returns True iff `instance` is a `namedtuple`.

  Args:
    instance: An instance of a Python object.
    strict: If True, only "plain" namedtuples (whose class derives directly
      from `tuple`) qualify, so a class inheriting from a `namedtuple` is
      rejected. With `strict=False` such subclasses are accepted as well.

  Returns:
    True if `instance` is a `namedtuple`.
  """
  # The extension function takes its arguments positionally.
  result = _tree.is_namedtuple(instance, strict)
  return result
def _sequence_like(instance, args):
  """Packs the elements `args` into a container of the same type as `instance`.

  Args:
    instance: the container to imitate: a `tuple`, `list`, `namedtuple`,
      attrs instance, mapping (e.g. `dict`, `collections.OrderedDict`,
      `collections.defaultdict`, `types.MappingProxyType`), mapping view, or
      a `wrapt.ObjectProxy`.
    args: the elements to pack. For mappings they are matched to the keys of
      `instance` in sorted-key order.

  Returns:
    `args` repackaged with the type of `instance`. Mapping views cannot be
    constructed directly, so for them a plain `list` is returned.

  Raises:
    TypeError: if a mapping's keys are not sortable, or if a namedtuple or
      attrs class cannot be constructed from `args`.
  """
  if isinstance(instance, (dict, collections_abc.Mapping)):
    # Values are matched to keys in *sorted* order so packing is
    # deterministic. The original order of an `OrderedDict` is deliberately
    # ignored: this avoids bugs when mixing ordered and plain dicts (e.g.
    # flattening a dict but packing it back via an equivalent `OrderedDict`).
    # The rebuilt mapping still lists keys in `instance`'s iteration order.
    values_by_key = dict(zip(_sorted(instance), args))
    pairs = ((k, values_by_key[k]) for k in instance)
    mapping_cls = type(instance)
    if isinstance(instance, collections.defaultdict):
      # `defaultdict` takes its default factory as the first argument.
      return mapping_cls(instance.default_factory, pairs)
    if isinstance(instance, types.MappingProxyType):
      # A `MappingProxyType` has to proxy an actual dict.
      return mapping_cls(dict(pairs))
    return mapping_cls(pairs)

  if isinstance(instance, collections_abc.MappingView):
    # Mapping views cannot be constructed directly; fall back to a list.
    return list(args)

  if _is_namedtuple(instance) or _is_attrs(instance):
    # For a proxied object, build the wrapped type (the result is unwrapped).
    target_cls = (type(instance.__wrapped__)
                  if isinstance(instance, ObjectProxy) else type(instance))
    try:
      if not _is_attrs(instance):
        return target_cls(*args)
      fields = target_cls.__attrs_attrs__
      return target_cls(**{field.name: value
                           for field, value in zip(fields, args)})
    except Exception as e:
      raise TypeError(
          f"Couldn't traverse {instance!r} with arguments {args}") from e

  if isinstance(instance, ObjectProxy):
    # Rebuild the wrapped object, then wrap it again in the same proxy type.
    return type(instance)(_sequence_like(instance.__wrapped__, args))

  # Plain sequences (tuple, list, ...).
  return type(instance)(args)
tree-0.1.9/tree/tree.cc 0000664 0000000 0000000 00000060364 14746732372 0014715 0 ustar 00root root 0000000 0000000 /* Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tree.h"
#include
#include
#include
#include
// logging
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include
#ifdef LOG
#define LOG_WARNING(w) LOG(WARNING) << w;
#else
#include
#define LOG_WARNING(w) std::cerr << w << "\n";
#endif
#ifndef DCHECK
#define DCHECK(stmt)
#endif
namespace py = pybind11;
namespace tree {
namespace {
// PyObjectPtr wraps an underlying Python object and decrements the
// reference count in the destructor.
//
// This class does not acquire the GIL in the destructor, so the GIL must be
// held when the destructor is called.
using PyObjectPtr = std::unique_ptr;

// Upper bound on the number of distinct types remembered by each
// CachedTypeCheck. Types seen after the cache is full are re-checked on every
// lookup instead of being cached.
const int kMaxItemsInCache = 1024;

// Set after the "sets are not sequences" warning has been logged, so that the
// warning is emitted at most once.
bool WarnedThatSetIsNotSequence = false;
// Returns true if `o` is a `bytes`, `bytearray` or `str` instance (subclasses
// included). Such objects are never treated as sequences.
bool IsString(PyObject* o) {
  if (PyUnicode_Check(o)) return true;
  if (PyBytes_Check(o)) return true;
  return PyByteArray_Check(o);
}
// Approximates Python's `o.__class__.__name__`: returns the type's `tp_name`
// with everything up to and including the last '.' removed.
//
// The returned view points into the type's `tp_name` buffer, so it is only
// valid while the type object is alive.
absl::string_view GetClassName(PyObject* o) {
  // Reading ob_type directly avoids the new reference that PyObject_Type()
  // would hand back.
  const absl::string_view full_name(o->ob_type->tp_name);
  const size_t last_dot = full_name.rfind('.');
  if (last_dot == absl::string_view::npos) return full_name;
  return full_name.substr(last_dot + 1);
}
// Returns a human-readable description of `o` for error messages, formatted
// as "type=<class name> str=<str(o)>".
//
// Returns an empty string if `o` is null, or if `str(o)` cannot be computed
// or encoded as UTF-8. In the latter two cases the Python error raised by the
// failing call is left set.
std::string PyObjectToString(PyObject* o) {
  if (o == nullptr) {
    return "";
  }
  PyObject* str = PyObject_Str(o);
  if (str == nullptr) {
    return "";
  }
  // PyUnicode_AsUTF8 returns nullptr (with an exception set) when the string
  // cannot be encoded, e.g. because it contains lone surrogates.
  // Constructing a std::string from nullptr is undefined behavior.
  const char* utf8 = PyUnicode_AsUTF8(str);
  if (utf8 == nullptr) {
    Py_DECREF(str);
    return "";
  }
  // Copy before releasing `str`: `utf8` points into its internal buffer.
  std::string s(utf8);
  Py_DECREF(str);
  return absl::StrCat("type=", GetClassName(o), " str=", s);
}
// Memoizes a per-type predicate.
//
// The predicate's result for an object is stored under the object's exact
// type, so later lookups for objects of the same type skip the (possibly
// Python-level) check. Subclasses are looked up under their own type.
// NOTE(review): the map is not guarded by a lock; this presumably relies on
// callers holding the GIL — confirm.
class CachedTypeCheck {
 public:
  explicit CachedTypeCheck(std::function ternary_predicate)
      : ternary_predicate_(std::move(ternary_predicate)) {}

  // Releases the references taken on the cached type objects.
  // (The instances in this file are heap-allocated statics that are never
  // deleted, so this presumably never runs in practice.)
  ~CachedTypeCheck() {
    for (const auto& pair : type_to_sequence_map_) {
      Py_DECREF(pair.first);
    }
  }

  // Caches successful executions of the one-argument (PyObject*) callable
  // "ternary_predicate" based on the type of "o". -1 from the callable
  // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type
  // does not match the predicate, and 1 indicates that it does. Used to avoid
  // calling back into Python for expensive isinstance checks.
  int CachedLookup(PyObject* o) {
    // Try not to return to Python - see if the type has already been seen
    // before.
    auto* type = Py_TYPE(o);
    {
      auto it = type_to_sequence_map_.find(type);
      if (it != type_to_sequence_map_.end()) {
        return it->second;
      }
    }

    int check_result = ternary_predicate_(o);
    if (check_result == -1) {
      return -1;  // Type check error, not cached.
    }

    // NOTE: This is never decref'd as long as the object lives, which is likely
    // forever, but we don't want the type to get deleted as long as it is in
    // the map. This should not be too much of a leak, as there should only be a
    // relatively small number of types in the map, and an even smaller number
    // that are eligible for decref. As a precaution, we limit the size of the
    // map to 1024.
    {
      // Once the cache is full, new types are simply not memoized.
      if (type_to_sequence_map_.size() < kMaxItemsInCache) {
        Py_INCREF(type);
        type_to_sequence_map_.insert({type, check_result});
      }
    }

    return check_result;
  }

 private:
  // Predicate returning -1 (error), 0 (no match) or 1 (match).
  std::function ternary_predicate_;
  // Type -> cached predicate result. Each key holds a reference to its type.
  std::unordered_map type_to_sequence_map_;
};
// Imports `collections.abc` and returns its attribute `name`.
py::object ImportCollectionsAbcAttr(const char* name) {
  py::module abc_module = py::module::import("collections.abc");
  return abc_module.attr(name);
}

// Returns `collections.abc.Sequence`. It is looked up on first use and then
// kept in a function-local static.
py::object GetCollectionsSequenceType() {
  static py::object sequence_type = ImportCollectionsAbcAttr("Sequence");
  return sequence_type;
}

// Returns `collections.abc.Mapping`. It is looked up on first use and then
// kept in a function-local static.
py::object GetCollectionsMappingType() {
  static py::object mapping_type = ImportCollectionsAbcAttr("Mapping");
  return mapping_type;
}

// Returns `collections.abc.MappingView`. It is looked up on first use and
// then kept in a function-local static.
py::object GetCollectionsMappingViewType() {
  static py::object mapping_view_type = ImportCollectionsAbcAttr("MappingView");
  return mapping_view_type;
}
// Returns `wrapt.ObjectProxy`, or `None` if `wrapt` cannot be imported.
// Any other Python error propagates as py::error_already_set.
py::object GetWraptObjectProxyTypeUncached() {
  try {
    return py::module::import("wrapt").attr("ObjectProxy");
  } catch (const py::error_already_set& e) {
    if (e.matches(PyExc_ImportError)) return py::none();
    // Rethrow the in-flight exception object; `throw e;` would throw a copy.
    throw;
  }
}

// Returns `wrapt.ObjectProxy` (or `None` when wrapt is unavailable).
// Caching is currently disabled, so every call repeats the lookup.
py::object GetWraptObjectProxyType() {
  // TODO(gregthornton): Restore caching when deadlock issue is fixed.
  return GetWraptObjectProxyTypeUncached();
}
// Returns 1 if `o` is a dict (or dict subclass) or a
// `collections.abc.Mapping`, i.e. a mapping for the purposes of Flatten().
// Returns 0 otherwise, and -1 if the isinstance check raised.
int IsMappingHelper(PyObject* o) {
  static auto* const mapping_cache =
      new CachedTypeCheck([](PyObject* candidate) {
        const py::object mapping_type = GetCollectionsMappingType();
        return PyObject_IsInstance(candidate, mapping_type.ptr());
      });
  // Dicts are recognized without consulting the cache.
  if (PyDict_Check(o)) return 1;
  return mapping_cache->CachedLookup(o);
}
// Returns 1 if `o` is a `collections.abc.MappingView`, i.e. a mapping view
// for the purposes of Flatten(). Returns 0 otherwise, and -1 if the isinstance
// check raised.
int IsMappingViewHelper(PyObject* o) {
  static auto* const view_cache = new CachedTypeCheck([](PyObject* candidate) {
    const py::object view_type = GetCollectionsMappingViewType();
    return PyObject_IsInstance(candidate, view_type.ptr());
  });
  return view_cache->CachedLookup(o);
}
// Returns 1 if `o` is a `wrapt.ObjectProxy`, and 0 otherwise (including when
// wrapt is not installed).
// NOTE(review): a failing isinstance check is reported as 0 (and cached)
// rather than -1, and the Python error it raised is not cleared.
int IsObjectProxy(PyObject* o) {
  static auto* const proxy_cache =
      new CachedTypeCheck([](PyObject* candidate) -> int {
        const py::object proxy_type = GetWraptObjectProxyType();
        if (proxy_type.is_none()) return 0;
        return PyObject_IsInstance(candidate, proxy_type.ptr()) == 1 ? 1 : 0;
      });
  return proxy_cache->CachedLookup(o);
}
// Returns 1 if the class of `o` defines `__attrs_attrs__` (i.e. it is an
// attrs-decorated class), 0 otherwise. Attribute lookup failures are cleared
// and reported as 0.
int IsAttrsHelper(PyObject* o) {
  static auto* const attrs_cache =
      new CachedTypeCheck([](PyObject* candidate) -> int {
        PyObjectPtr klass(PyObject_GetAttrString(candidate, "__class__"));
        if (!klass) {
          // PyObject_GetAttrString signals failure with nullptr and a pending
          // exception; treat the object as not-attrs.
          PyErr_Clear();
          return 0;
        }
        return PyObject_HasAttrString(klass.get(), "__attrs_attrs__");
      });
  return attrs_cache->CachedLookup(o);
}
// Returns 1 if `o` is considered a sequence for the purposes of Flatten():
// mappings, mapping views, attrs instances, and non-string
// `collections.abc.Sequence`s.
// Returns 0 otherwise.
// Returns -1 (with a Python error set) if one of the checks failed.
int IsSequenceHelper(PyObject* o) {
  static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
    int is_instance =
        PyObject_IsInstance(to_check, GetCollectionsSequenceType().ptr());

    // Don't cache a failed is_instance check.
    if (is_instance == -1) return -1;

    return static_cast<int>(is_instance != 0 && !IsString(to_check));
  });
  // We treat dicts and other mappings (and their views) as special cases of
  // sequences. These helpers return -1 on error; that must be propagated
  // instead of being mistaken for "true" because it is nonzero.
  const int is_mapping = IsMappingHelper(o);
  if (is_mapping != 0) return is_mapping;
  const int is_mapping_view = IsMappingViewHelper(o);
  if (is_mapping_view != 0) return is_mapping_view;
  if (IsAttrsHelper(o)) return 1;
  if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
    LOG_WARNING(
        "Sets are not currently considered sequences, "
        "but this may change in the future, "
        "so consider avoiding using them.");
    WarnedThatSetIsNotSequence = true;
  }
  return check_cache->CachedLookup(o);
}
// Owning pointer to a ValueIterator, as returned by GetValueIterator().
using ValueIteratorPtr = std::unique_ptr;
// Iterate through dictionaries in a deterministic order by sorting the
// keys. Notice this means that we ignore the original order of
// `OrderedDict` instances. This is intentional, to avoid potential
// bugs caused by mixing ordered and plain dicts (e.g., flattening
// a dict but using a corresponding `OrderedDict` to pack it back).
//
// The iterator is invalidated (with a Python error set) if the keys cannot be
// fetched, sorted (e.g. keys that are not mutually comparable) or iterated.
// `dict` is borrowed: the caller must keep it alive while iterating.
class DictValueIterator : public ValueIterator {
 public:
  explicit DictValueIterator(PyObject* dict)
      : dict_(dict), keys_(PyDict_Keys(dict)) {
    // PyDict_Keys can fail (e.g. out of memory); check before sorting, as
    // MappingValueIterator does.
    if (!keys_ || PyList_Sort(keys_.get()) == -1) {
      invalidate();
      return;
    }
    iter_.reset(PyObject_GetIter(keys_.get()));
    if (!iter_) invalidate();
  }

  // Returns a new reference to the value of the next key. Returns nullptr
  // when the keys are exhausted or on error (check PyErr_Occurred()).
  PyObjectPtr next() override {
    PyObjectPtr result;
    PyObjectPtr key(PyIter_Next(iter_.get()));
    if (key) {
      // PyDict_GetItemWithError returns a borrowed reference. Unlike
      // PyDict_GetItem, it does not swallow exceptions raised by the key's
      // __hash__/__eq__; those propagate instead of being misreported as a
      // concurrent modification.
      PyObject* elem = PyDict_GetItemWithError(dict_, key.get());
      if (elem) {
        Py_INCREF(elem);
        result.reset(elem);
      } else if (!PyErr_Occurred()) {
        PyErr_SetString(PyExc_RuntimeError,
                        "Dictionary was modified during iteration over it");
      }
    }
    return result;
  }

 private:
  PyObject* dict_;
  PyObjectPtr keys_;
  PyObjectPtr iter_;
};
// Iterate over mapping objects by sorting the keys first, mirroring
// DictValueIterator for mappings that are not dicts.
//
// The iterator is invalidated (with a Python error set) if the keys cannot be
// fetched, sorted or iterated. `mapping` is borrowed: the caller must keep it
// alive while iterating.
class MappingValueIterator : public ValueIterator {
 public:
  explicit MappingValueIterator(PyObject* mapping)
      : mapping_(mapping), keys_(PyMapping_Keys(mapping)) {
    if (!keys_ || PyList_Sort(keys_.get()) == -1) {
      invalidate();
      return;
    }
    iter_.reset(PyObject_GetIter(keys_.get()));
    if (!iter_) invalidate();
  }

  // Returns a new reference to the value of the next key. Returns nullptr
  // when the keys are exhausted or on error (check PyErr_Occurred()).
  PyObjectPtr next() override {
    PyObjectPtr result;
    PyObjectPtr key(PyIter_Next(iter_.get()));
    if (key) {
      // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
      PyObject* elem = PyObject_GetItem(mapping_, key.get());
      if (elem) {
        result.reset(elem);
      } else if (PyErr_ExceptionMatches(PyExc_KeyError)) {
        // A key reported by keys() is no longer present.
        PyErr_SetString(PyExc_RuntimeError,
                        "Mapping was modified during iteration over it");
      }
      // Any other exception raised by __getitem__ is left set, so the caller
      // sees the real error.
    }
    return result;
  }

 private:
  PyObject* mapping_;
  PyObjectPtr keys_;
  PyObjectPtr iter_;
};
// Iterate over a sequence, by index. The input is first materialized with
// PySequence_Fast, so any iterable is accepted.
//
// The iterator is invalidated (with a Python error set) if `iterable` cannot
// be converted, e.g. because it is not iterable.
class SequenceValueIterator : public ValueIterator {
 public:
  explicit SequenceValueIterator(PyObject* iterable)
      : seq_(PySequence_Fast(iterable, "")),
        size_(seq_.get() ? PySequence_Fast_GET_SIZE(seq_.get()) : 0),
        index_(0) {
    // Report a failed conversion through valid(). Otherwise callers would
    // see an empty sequence while a Python error is pending.
    if (!seq_) invalidate();
  }

  // Returns a new reference to the next item, or nullptr when exhausted.
  PyObjectPtr next() override {
    PyObjectPtr result;
    if (index_ < size_) {
      // PySequence_Fast_GET_ITEM returns a borrowed reference.
      PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
      ++index_;
      if (elem) {
        Py_INCREF(elem);
        result.reset(elem);
      }
    }
    return result;
  }

 private:
  PyObjectPtr seq_;
  const Py_ssize_t size_;
  Py_ssize_t index_;
};
// Iterates over the attribute values of an attrs instance, in the order of
// its class's `__attrs_attrs__`. Holds its own reference to `nested`.
// The iterator is invalidated (with a Python error set) if the attrs metadata
// cannot be fetched or iterated.
class AttrsValueIterator : public ValueIterator {
 public:
  explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
    // nested_ releases a reference on destruction, so take one here.
    Py_INCREF(nested);
    cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
    if (cls_) {
      attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
      if (attrs_) {
        iter_.reset(PyObject_GetIter(attrs_.get()));
      }
    }
    if (!iter_ || PyErr_Occurred()) invalidate();
  }

  // Returns a new reference to the next attribute value. Returns nullptr when
  // the attributes are exhausted or on error (check PyErr_Occurred()).
  PyObjectPtr next() override {
    PyObjectPtr result;
    PyObjectPtr item(PyIter_Next(iter_.get()));
    if (item) {
      PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
      // Passing a null name to PyObject_GetAttr would crash; leave the
      // lookup error set instead.
      if (name) {
        result.reset(PyObject_GetAttr(nested_.get(), name.get()));
      }
    }
    return result;
  }

 private:
  PyObjectPtr nested_;
  PyObjectPtr cls_;
  PyObjectPtr attrs_;
  PyObjectPtr iter_;
};
// Appends the leaves of `nested` to `list`, depth first.
//
// `is_sequence_helper` returns 1 for containers, 0 for leaves and -1 on error.
// `value_iterator_getter` produces an iterator over a container's children.
// Returns false, with a Python error set, if anything failed.
bool FlattenHelper(
    PyObject* nested, PyObject* list,
    const std::function<int(PyObject*)>& is_sequence_helper,
    const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
  // If nested is not a sequence, append itself and exit.
  const int is_seq = is_sequence_helper(nested);
  if (is_seq == -1) return false;
  if (!is_seq) {
    return PyList_Append(list, nested) != -1;
  }

  ValueIteratorPtr iter = value_iterator_getter(nested);
  if (!iter->valid()) return false;

  for (PyObjectPtr item = iter->next(); item; item = iter->next()) {
    if (Py_EnterRecursiveCall(" in flatten")) {
      return false;
    }
    const bool success = FlattenHelper(item.get(), list, is_sequence_helper,
                                       value_iterator_getter);
    Py_LeaveRecursiveCall();
    if (!success) {
      return false;
    }
  }
  // next() returns nullptr both when the children are exhausted and on error
  // (e.g. a dict modified during iteration). Only the former is success;
  // otherwise the outer levels would keep calling the C API with an
  // exception pending.
  return PyErr_Occurred() == nullptr;
}
// Fills `error_msg` with a message listing the keys of both mappings and sets
// `is_type_error` to false.
//
// If the keys of either mapping cannot be fetched, a shorter message is used
// instead. In that case `is_type_error` is left untouched and the Python
// error from the failed lookup stays set.
void SetDifferentKeysError(PyObject* dict1, PyObject* dict2,
                           std::string* error_msg, bool* is_type_error) {
  static const char kFetchFailed[] =
      "The two dictionaries don't have the same set of keys. Failed to "
      "fetch keys.";

  PyObjectPtr first_keys(PyMapping_Keys(dict1));
  if (PyErr_Occurred() || first_keys.get() == nullptr) {
    *error_msg = kFetchFailed;
    return;
  }
  PyObjectPtr second_keys(PyMapping_Keys(dict2));
  if (PyErr_Occurred() || second_keys.get() == nullptr) {
    *error_msg = kFetchFailed;
    return;
  }

  *is_type_error = false;
  *error_msg = absl::StrCat(
      "The two dictionaries don't have the same set of keys. "
      "First structure has keys ",
      PyObjectToString(first_keys.get()), ", while second structure has keys ",
      PyObjectToString(second_keys.get()));
}
// Recursively compares the nested structure of `o1` and `o2`.
//
// Returns true iff there were no "internal" errors. In other words,
// errors that have nothing to do with structure checking.
// If an "internal" error occurred, the appropriate Python error will be
// set and the caller can propagate it directly to the user.
//
// Both `error_msg` and `is_type_error` must be non-null. `error_msg` must
// be empty.
// Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
// with appropriate error and sets `is_type_error` to true iff
// the error to be raised should be TypeError.
//
// With `check_types`, container types must also agree:
// - namedtuples are compared by name and fields (SameNamedtuples);
// - two lists, or two mappings, of different types are accepted;
// - mappings must have the same keys.
bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
                               std::string* error_msg, bool* is_type_error) {
  DCHECK(error_msg);
  DCHECK(is_type_error);
  const bool is_seq1 = IsSequence(o1);
  const bool is_seq2 = IsSequence(o2);
  if (PyErr_Occurred()) return false;
  if (is_seq1 != is_seq2) {
    std::string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2);
    std::string non_seq_str =
        is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1);
    *is_type_error = false;
    *error_msg = absl::StrCat("Substructure \"", seq_str,
                              "\" is a sequence, while substructure \"",
                              non_seq_str, "\" is not");
    return true;
  }

  // Got to scalars, so finished checking. Structures are the same.
  if (!is_seq1) return true;

  if (check_types) {
    // Unwrap wrapt.ObjectProxy if needed.
    // NOTE(review): if the `__wrapped__` lookup fails, o1/o2 becomes null and
    // is dereferenced just below. Also, o1_wrapped/o2_wrapped are released at
    // the end of this `if`, but o1/o2 keep pointing at the wrapped objects
    // afterwards; this presumably relies on the proxies keeping them alive —
    // confirm.
    PyObjectPtr o1_wrapped;
    if (IsObjectProxy(o1)) {
      o1_wrapped.reset(PyObject_GetAttrString(o1, "__wrapped__"));
      o1 = o1_wrapped.get();
    }
    PyObjectPtr o2_wrapped;
    if (IsObjectProxy(o2)) {
      o2_wrapped.reset(PyObject_GetAttrString(o2, "__wrapped__"));
      o2 = o2_wrapped.get();
    }

    const PyTypeObject* type1 = o1->ob_type;
    const PyTypeObject* type2 = o2->ob_type;

    // We treat two different namedtuples with identical name and fields
    // as having the same type.
    // IsNamedtuple returns a new reference to Py_True/Py_False, or nullptr on
    // error; both references are released before branching on the result.
    const PyObject* o1_tuple = IsNamedtuple(o1, true);
    if (o1_tuple == nullptr) return false;
    const PyObject* o2_tuple = IsNamedtuple(o2, true);
    if (o2_tuple == nullptr) {
      Py_DECREF(o1_tuple);
      return false;
    }
    bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True;
    Py_DECREF(o1_tuple);
    Py_DECREF(o2_tuple);

    if (both_tuples) {
      const PyObject* same_tuples = SameNamedtuples(o1, o2);
      if (same_tuples == nullptr) return false;
      bool not_same_tuples = same_tuples != Py_True;
      Py_DECREF(same_tuples);
      if (not_same_tuples) {
        *is_type_error = true;
        *error_msg = absl::StrCat(
            "The two namedtuples don't have the same sequence type. "
            "First structure ",
            PyObjectToString(o1), " has type ", type1->tp_name,
            ", while second structure ", PyObjectToString(o2), " has type ",
            type2->tp_name);
        return true;
      }
    } else if (type1 != type2
               /* If both sequences are list types, don't complain. This allows
                  one to be a list subclass (e.g. _ListWrapper used for
                  automatic dependency tracking.) */
               && !(PyList_Check(o1) && PyList_Check(o2))
               /* Two mapping types will also compare equal, making _DictWrapper
                  and dict compare equal. */
               && !(IsMappingHelper(o1) && IsMappingHelper(o2))) {
      // NOTE(review): IsMappingHelper returns -1 on error, which is truthy
      // here. The message below says "namedtuples" although this branch
      // handles any two containers of different types.
      *is_type_error = true;
      *error_msg = absl::StrCat(
          "The two namedtuples don't have the same sequence type. "
          "First structure ",
          PyObjectToString(o1), " has type ", type1->tp_name,
          ", while second structure ", PyObjectToString(o2), " has type ",
          type2->tp_name);
      return true;
    }

    // Mappings must additionally have identical key sets.
    if (PyDict_Check(o1) && PyDict_Check(o2)) {
      // Dict fast path: same size and every key of o1 present in o2.
      if (PyDict_Size(o1) != PyDict_Size(o2)) {
        SetDifferentKeysError(o1, o2, error_msg, is_type_error);
        return true;
      }

      PyObject* key;
      Py_ssize_t pos = 0;
      while (PyDict_Next(o1, &pos, &key, nullptr)) {
        if (PyDict_GetItem(o2, key) == nullptr) {
          SetDifferentKeysError(o1, o2, error_msg, is_type_error);
          return true;
        }
      }
    } else if (IsMappingHelper(o1)) {
      // Fallback for custom mapping types. Instead of using PyDict methods
      // which stay in C, we call iter(o1).
      // NOTE(review): the result of PyObject_GetIter is not null-checked, and
      // the loop does not check PyErr_Occurred() after PyIter_Next returns
      // nullptr, so iteration errors are indistinguishable from "no more
      // keys" here.
      if (PyMapping_Size(o1) != PyMapping_Size(o2)) {
        SetDifferentKeysError(o1, o2, error_msg, is_type_error);
        return true;
      }

      PyObjectPtr iter(PyObject_GetIter(o1));
      PyObject* key;
      while ((key = PyIter_Next(iter.get())) != nullptr) {
        if (!PyMapping_HasKey(o2, key)) {
          SetDifferentKeysError(o1, o2, error_msg, is_type_error);
          Py_DECREF(key);
          return true;
        }
        Py_DECREF(key);
      }
    }
  }

  // Recurse into the children pairwise. A length mismatch shows up as one
  // iterator running out before the other.
  ValueIteratorPtr iter1 = GetValueIterator(o1);
  ValueIteratorPtr iter2 = GetValueIterator(o2);

  if (!iter1->valid() || !iter2->valid()) return false;

  while (true) {
    PyObjectPtr v1 = iter1->next();
    PyObjectPtr v2 = iter2->next();
    if (v1 && v2) {
      if (Py_EnterRecursiveCall(" in assert_same_structure")) {
        return false;
      }
      bool no_internal_errors = AssertSameStructureHelper(
          v1.get(), v2.get(), check_types, error_msg, is_type_error);
      Py_LeaveRecursiveCall();
      if (!no_internal_errors) return false;
      // Stop at the first mismatch found in a child.
      if (!error_msg->empty()) return true;
    } else if (!v1 && !v2) {
      // Done with all recursive calls. Structure matched.
      return true;
    } else {
      *is_type_error = false;
      *error_msg = absl::StrCat(
          "The two structures don't have the same number of elements. ",
          "First structure: ", PyObjectToString(o1),
          ". Second structure: ", PyObjectToString(o2));
      return true;
    }
  }
}
} // namespace
// Public wrappers that collapse the helpers' tri-state result
// (-1 error, 0 no, 1 yes) to a bool. An error yields false, and the Python
// error is left set.
bool IsSequence(PyObject* o) {
  const int status = IsSequenceHelper(o);
  return status == 1;
}
bool IsAttrs(PyObject* o) {
  const int status = IsAttrsHelper(o);
  return status == 1;
}
// Returns a new list with the leaves of `nested` in traversal order, or
// nullptr with a Python error set on failure.
PyObject* Flatten(PyObject* nested) {
  PyObject* list = PyList_New(0);
  // PyList_New returns nullptr (with MemoryError set) on allocation failure.
  if (list == nullptr) return nullptr;
  if (!FlattenHelper(nested, list, IsSequenceHelper, GetValueIterator)) {
    Py_DECREF(list);
    return nullptr;
  }
  return list;
}
// Returns a new reference to Py_True if `o` (or, for a `wrapt.ObjectProxy`,
// the object it wraps) looks like a namedtuple: a tuple whose `_fields`
// attribute is a sequence of strings. With `strict`, the class must also
// derive directly from `tuple`. Returns a new reference to Py_False otherwise,
// or nullptr with a Python error set on failure.
PyObject* IsNamedtuple(PyObject* o, bool strict) {
  // Unwrap wrapt.ObjectProxy if needed.
  PyObjectPtr o_wrapped;
  if (IsObjectProxy(o)) {
    o_wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
    // Without this check a failed lookup would be dereferenced below.
    if (!o_wrapped) return nullptr;
    o = o_wrapped.get();
  }

  // Must be subclass of tuple.
  if (!PyTuple_Check(o)) {
    Py_RETURN_FALSE;
  }

  // If strict, o.__class__.__base__ must be tuple.
  if (strict) {
    PyObjectPtr klass(PyObject_GetAttrString(o, "__class__"));
    if (!klass) return nullptr;
    PyObjectPtr base(PyObject_GetAttrString(klass.get(), "__base__"));
    if (!base) return nullptr;
    // Built-in type objects are singletons, so identity comparison suffices.
    if (reinterpret_cast<PyTypeObject*>(base.get()) != &PyTuple_Type) {
      Py_RETURN_FALSE;
    }
  }

  // o must have attribute '_fields' and every element in
  // '_fields' must be a string.
  if (!PyObject_HasAttrString(o, "_fields")) {
    Py_RETURN_FALSE;
  }

  PyObjectPtr fields(PyObject_GetAttrString(o, "_fields"));
  if (!fields) return nullptr;
  const int is_instance =
      PyObject_IsInstance(fields.get(), GetCollectionsSequenceType().ptr());
  if (is_instance == 0) {
    Py_RETURN_FALSE;
  } else if (is_instance == -1) {
    return nullptr;
  }

  PyObjectPtr seq(PySequence_Fast(fields.get(), ""));
  if (!seq) return nullptr;
  const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get());
  for (Py_ssize_t i = 0; i < s; ++i) {
    // PySequence_Fast_GET_ITEM returns a borrowed reference.
    PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i);
    if (!IsString(elem)) {
      Py_RETURN_FALSE;
    }
  }

  Py_RETURN_TRUE;
}
// Returns a new reference to Py_True if `o1` and `o2` have equal `_fields`
// and the same unqualified class name, and to Py_False if not. Returns nullptr
// with a Python error set on failure. A missing `_fields` attribute is
// reported as RuntimeError.
PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
  // Own the attribute values so they are released on every return path.
  // The second lookup is skipped if the first one already failed, to avoid
  // calling into Python with an exception pending.
  PyObjectPtr f1(PyObject_GetAttrString(o1, "_fields"));
  PyObjectPtr f2(f1 ? PyObject_GetAttrString(o2, "_fields") : nullptr);
  if (!f1 || !f2) {
    PyErr_SetString(
        PyExc_RuntimeError,
        "Expected namedtuple-like objects (that have _fields attr)");
    return nullptr;
  }

  const int fields_differ =
      PyObject_RichCompareBool(f1.get(), f2.get(), Py_NE);
  // -1 means the comparison itself raised; it must not be read as "differ".
  if (fields_differ == -1) return nullptr;
  if (fields_differ) {
    Py_RETURN_FALSE;
  }

  if (GetClassName(o1) == GetClassName(o2)) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}
// Checks that `o1` and `o2` have the same nested structure and, with
// `check_types`, the same container types. On mismatch it sets a Python
// TypeError or ValueError describing the first difference found. A Python
// error raised during the check itself is left in place unchanged.
void AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types) {
  std::string mismatch;
  bool mismatch_is_type_error = false;
  AssertSameStructureHelper(o1, o2, check_types, &mismatch,
                            &mismatch_is_type_error);

  // Don't hide Python exceptions raised while checking (e.g. errors fetching
  // keys from custom mappings). Only report a structural mismatch when no
  // such exception is pending.
  if (PyErr_Occurred() || mismatch.empty()) return;

  PyObject* exc_type =
      mismatch_is_type_error ? PyExc_TypeError : PyExc_ValueError;
  const std::string message = absl::StrCat(
      "The two structures don't have the same nested structure.\n\n",
      "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
      PyObjectToString(o2), "\n\nMore specifically: ", mismatch);
  PyErr_SetString(exc_type, message.c_str());
}
// Returns an iterator over the children of the container `nested`, chosen by
// container kind, in this order: dicts (and dict subclasses), other mappings,
// attrs instances, and finally any other iterable. Callers must check valid()
// before calling next().
ValueIteratorPtr GetValueIterator(PyObject* nested) {
  if (PyDict_Check(nested)) {
    return absl::make_unique<DictValueIterator>(nested);
  }
  if (IsMappingHelper(nested)) {
    return absl::make_unique<MappingValueIterator>(nested);
  }
  if (IsAttrsHelper(nested)) {
    return absl::make_unique<AttrsValueIterator>(nested);
  }
  return absl::make_unique<SequenceValueIterator>(nested);
}
namespace {
// Wraps a new reference returned by a tree:: function in a py::object.
// Throws py::error_already_set if the call failed (nullptr) or left a Python
// error pending.
inline py::object pyo_or_throw(PyObject* ptr) {
  if (PyErr_Occurred() || ptr == nullptr) {
    // The result is discarded when an error is pending; release it rather
    // than leak it.
    Py_XDECREF(ptr);
    throw py::error_already_set();
  }
  return py::reinterpret_steal<py::object>(ptr);
}
PYBIND11_MODULE(_tree, m) {
// Resolve `wrapt.ObjectProxy` at import time to avoid doing
// imports during function calls.
tree::GetWraptObjectProxyType();
m.def("assert_same_structure",
[](py::handle& o1, py::handle& o2, bool check_types) {
tree::AssertSameStructure(o1.ptr(), o2.ptr(), check_types);
if (PyErr_Occurred()) {
throw py::error_already_set();
}
});
m.def("is_sequence", [](py::handle& o) {
bool result = tree::IsSequence(o.ptr());
if (PyErr_Occurred()) {
throw py::error_already_set();
}
return result;
});
m.def("is_namedtuple", [](py::handle& o, bool strict) {
return pyo_or_throw(tree::IsNamedtuple(o.ptr(), strict));
});
m.def("is_attrs", [](py::handle& o) {
bool result = tree::IsAttrs(o.ptr());
if (PyErr_Occurred()) {
throw py::error_already_set();
}
return result;
});
m.def("same_namedtuples", [](py::handle& o1, py::handle& o2) {
return pyo_or_throw(tree::SameNamedtuples(o1.ptr(), o2.ptr()));
});
m.def("flatten", [](py::handle& nested) {
return pyo_or_throw(tree::Flatten(nested.ptr()));
});
}
} // namespace
} // namespace tree
tree-0.1.9/tree/tree.h 0000664 0000000 0000000 00000011273 14746732372 0014552 0 ustar 00root root 0000000 0000000 /* Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TREE_H_
#define TREE_H_
#include
#include
namespace tree {
// Returns true if its input is a collections.abc.Sequence (except strings),
// a mapping, a mapping view, or an instance of an attrs-decorated class.
//
// Args:
//   o: the object to check.
//
// Returns:
//   True if the object is considered a sequence for the purposes of
//   flattening.
bool IsSequence(PyObject* o);
// Returns Py_True iff `o` should be considered a `namedtuple`.
//
// Args:
//   o: An instance of a Python object.
//   strict: If True, `o` is considered to be a `namedtuple` only if
//     it is a "plain" namedtuple. For instance, a class inheriting
//     from a `namedtuple` will be considered to be a `namedtuple`
//     iff `strict=False`.
//
// Returns:
//   A new reference to Py_True if `o` is a `namedtuple`, to Py_False
//   otherwise, or nullptr (with a Python error set) on failure.
PyObject* IsNamedtuple(PyObject* o, bool strict);
// Returns a true if its input is an instance of an attr.s decorated class.
//
// Args:
// o: the input to be checked.
//
// Returns:
// True if the object is an instance of an attr.s decorated class.
bool IsAttrs(PyObject* o);
// Returns Py_True iff the two namedtuples have the same name and fields.
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
// '_fields' attribute).
PyObject* SameNamedtuples(PyObject* o1, PyObject* o2);
// Asserts that two structures are nested in the same way.
//
// Note that namedtuples with identical name and fields are always considered
// to have the same shallow structure (even with `check_types=True`).
// For intance, this code will print `True`:
//
// ```python
// def nt(a, b):
// return collections.namedtuple('foo', 'a b')(a, b)
// print(assert_same_structure(nt(0, 1), nt(2, 3)))
// ```
//
// Args:
// nest1: an arbitrarily nested structure.
// nest2: an arbitrarily nested structure.
// check_types: if `true`, types of sequences are checked as
// well, including the keys of dictionaries. If set to `false`, for example
// a list and a tuple of objects will look the same if they have the same
// size. Note that namedtuples with identical name and fields are always
// considered to have the same shallow structure.
//
// Raises:
// ValueError: If the two structures do not have the same number of elements or
// if the two structures are not nested in the same way.
// TypeError: If the two structures differ in the type of sequence in any of
// their substructures. Only possible if `check_types` is `True`.
void AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types);
//
// Returns a flat list from a given nested structure.
//
// If `nest` is not a sequence, tuple, or dict, then returns a single-element
// list: `[nest]`.
//
// In the case of dict instances, the sequence consists of the values, sorted by
// key to ensure deterministic behavior. This is true also for `OrderedDict`
// instances: their sequence order is ignored, the sorting order of keys is
// used instead. The same convention is followed in `pack_sequence_as`. This
// correctly repacks dicts and `OrderedDict`s after they have been flattened,
// and also allows flattening an `OrderedDict` and then repacking it back using
// a corresponding plain dict, or vice-versa.
// Dictionaries with non-sortable keys cannot be flattened.
//
// Args:
// nest: an arbitrarily nested structure or a scalar object. Note, numpy
// arrays are considered scalars.
//
// Returns:
// A Python list, the flattened version of the input.
// On error, returns nullptr
//
// Raises:
// TypeError: The nest is or contains a dict with non-sortable keys.
PyObject* Flatten(PyObject* nested);
struct DecrementsPyRefcount {
void operator()(PyObject* p) const { Py_DECREF(p); }
};
// ValueIterator interface
class ValueIterator {
public:
virtual ~ValueIterator() {}
virtual std::unique_ptr next() = 0;
bool valid() const { return is_valid_; }
protected:
void invalidate() { is_valid_ = false; }
private:
bool is_valid_ = true;
};
std::unique_ptr GetValueIterator(PyObject* nested);
} // namespace tree
#endif // TREE_H_
tree-0.1.9/tree/tree_benchmark.py 0000664 0000000 0000000 00000003704 14746732372 0016765 0 ustar 00root root 0000000 0000000 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmarks for utilities working with arbitrarily nested structures."""
import collections
import timeit
import tree
# (seconds per unit, unit suffix), ordered from the largest unit to the
# smallest; format_time picks the first unit the duration reaches.
TIME_UNITS = [
    (1, "s"),
    (10**-3, "ms"),
    (10**-6, "us"),
    (10**-9, "ns"),
]


def format_time(time):
  """Formats a duration in seconds using the largest unit it reaches.

  Args:
    time: duration in seconds.

  Returns:
    A string such as "1.50ms". Durations below the smallest unit in
    `TIME_UNITS` (including 0) are expressed in that smallest unit rather
    than producing None.
  """
  for scale, unit in TIME_UNITS:
    # `>=` so that exactly one unit renders as "1.00s", not "1000.00ms".
    if time >= scale:
      return "{:.2f}{}".format(time / scale, unit)
  # Nothing matched (sub-nanosecond or zero): fall back to the smallest unit.
  scale, unit = TIME_UNITS[-1]
  return "{:.2f}{}".format(time / scale, unit)
def run_benchmark(benchmark_fn, num_iters):
  """Returns the mean number of seconds per call of `benchmark_fn`.

  `benchmark_fn` is timed in two rounds of `num_iters` calls each; the first
  round is discarded as warmup and only the second is reported.
  """
  warmup_total, measured_total = timeit.repeat(
      benchmark_fn, repeat=2, number=num_iters)
  del warmup_total  # Discarded: warmup round.
  return measured_total / num_iters
def map_to_list(func, *args):
  """Builtin-`map` baseline that materializes its lazy result as a list."""
  lazy_results = map(func, *args)
  return list(lazy_results)
def benchmark_map(map_fn, structure):
  """Builds a zero-arg benchmark that maps the identity over `structure`."""
  return lambda: map_fn(lambda v: v, structure)
# Benchmark name -> zero-argument callable, in reporting order. Every
# `tree.map_structure` case has a builtin-`map` baseline over the same input.
BENCHMARKS = collections.OrderedDict()
BENCHMARKS["tree_map_1"] = benchmark_map(tree.map_structure, [0])
BENCHMARKS["tree_map_8"] = benchmark_map(tree.map_structure, [0] * 8)
BENCHMARKS["tree_map_64"] = benchmark_map(tree.map_structure, [0] * 64)
BENCHMARKS["builtin_map_1"] = benchmark_map(map_to_list, [0])
BENCHMARKS["builtin_map_8"] = benchmark_map(map_to_list, [0] * 8)
BENCHMARKS["builtin_map_64"] = benchmark_map(map_to_list, [0] * 64)
def main():
  """Prints each benchmark's name followed by its mean time per call."""
  for benchmark_name in BENCHMARKS:
    seconds_per_call = run_benchmark(BENCHMARKS[benchmark_name],
                                     num_iters=1000)
    print(benchmark_name, format_time(seconds_per_call))


if __name__ == "__main__":
  main()
tree-0.1.9/tree/tree_test.py 0000664 0000000 0000000 00000135722 14746732372 0016020 0 ustar 00root root 0000000 0000000 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utilities working with arbitrarily nested structures."""
import collections
import doctest
import types
from typing import Any, Iterator, Mapping
import unittest
from absl.testing import parameterized
import attr
import numpy as np
import tree
import wrapt
# Two structures with identical nesting but different leaf values.
STRUCTURE1 = (
    ((1, 2), 3),
    4,
    (5, 6),
)
STRUCTURE2 = (
    (("foo1", "foo2"), "foo3"),
    "foo4",
    ("foo5", "foo6"),
)
# Structures that do not match STRUCTURE1: the first has a different number
# of top-level elements, the second nests its leaves differently.
STRUCTURE_DIFFERENT_NUM_ELEMENTS = ("spam", "eggs")
STRUCTURE_DIFFERENT_NESTING = (
    ((1, 2), 3),
    4,
    5,
    (6,),
)
class DoctestTest(parameterized.TestCase):
  """Runs the doctest examples embedded in the `tree` module."""

  def testDoctest(self):
    doctest_globals = dict(collections=collections, tree=tree)
    failed, attempted = doctest.testmod(
        tree, extraglobs=doctest_globals, optionflags=doctest.ELLIPSIS)
    # Guard against passing vacuously when no examples were collected.
    self.assertGreater(attempted, 0, "No doctests found.")
    self.assertEqual(failed, 0, "{} doctests failed".format(failed))
class NestTest(parameterized.TestCase):
def assertAllEquals(self, a, b):
  """Asserts that `a` and `b` are element-wise equal arrays of equal shape.

  Delegates to `np.testing.assert_array_equal`, which rejects shape
  mismatches (instead of silently broadcasting or failing with an opaque
  error from `.all()`) and reports the differing elements on failure.
  """
  np.testing.assert_array_equal(a, b)
def testAttrsFlattenAndUnflatten(self):
  """attrs instances flatten to their field values and round-trip."""

  class BadAttr(object):
    """Class that has a non-iterable __attrs_attrs__."""
    __attrs_attrs__ = None

  @attr.s
  class SampleAttr(object):
    field1 = attr.ib()
    field2 = attr.ib()

  values = [1, 2]
  instance = SampleAttr(*values)

  # Only the decorated instance is recognized as an attrs object.
  self.assertFalse(tree._is_attrs(values))
  self.assertTrue(tree._is_attrs(instance))

  flat = tree.flatten(instance)
  self.assertEqual(values, flat)
  rebuilt = tree.unflatten_as(instance, flat)
  self.assertIsInstance(rebuilt, SampleAttr)
  self.assertEqual(rebuilt, instance)

  # A non-iterable __attrs_attrs__ must surface as a TypeError.
  with self.assertRaisesRegex(TypeError, "object is not iterable"):
    tree.flatten(BadAttr())
@parameterized.parameters([
    (1, 2, 3),
    ({"B": 10, "A": 20}, [1, 2], 3),
    ((1, 2), [3, 4], 5),
    (collections.namedtuple("Point", ["x", "y"])(1, 2), 3, 4),
    wrapt.ObjectProxy(
        (collections.namedtuple("Point", ["x", "y"])(1, 2), 3, 4))
])
def testAttrsMapStructure(self, *field_values):
  """Identity-mapping an attrs instance yields an equal instance."""

  @attr.s
  class SampleAttr(object):
    # Fields are deliberately declared out of alphabetical order.
    field3 = attr.ib()
    field1 = attr.ib()
    field2 = attr.ib()

  original = SampleAttr(*field_values)
  mapped = tree.map_structure(lambda leaf: leaf, original)
  self.assertEqual(original, mapped)
def testFlattenAndUnflatten(self):
  # Plain nested tuples.
  nested = ((3, 4), 5, (6, 7, (9, 10), 8))
  letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
  self.assertEqual(tree.flatten(nested), [3, 4, 5, 6, 7, 9, 10, 8])
  self.assertEqual(
      tree.unflatten_as(nested, letters),
      (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))

  # Namedtuples flatten field by field and are rebuilt with their fields.
  point = collections.namedtuple("Point", ["x", "y"])
  nested = (point(x=4, y=2), ((point(x=1, y=0),),))
  leaves = [4, 2, 1, 0]
  self.assertEqual(tree.flatten(nested), leaves)
  rebuilt = tree.unflatten_as(nested, leaves)
  self.assertEqual(rebuilt, nested)
  outer_point = rebuilt[0]
  inner_point = rebuilt[1][0][0]
  self.assertEqual(outer_point.x, 4)
  self.assertEqual(outer_point.y, 2)
  self.assertEqual(inner_point.x, 1)
  self.assertEqual(inner_point.y, 0)

  # Scalars (numpy arrays included) flatten to a one-element list.
  self.assertEqual([5], tree.flatten(5))
  self.assertEqual([np.array([5])], tree.flatten(np.array([5])))
  self.assertEqual("a", tree.unflatten_as(5, ["a"]))
  self.assertEqual(
      np.array([5]), tree.unflatten_as("scalar", [np.array([5])]))

  # Error cases.
  with self.assertRaisesRegex(ValueError, "Structure is a scalar"):
    tree.unflatten_as("scalar", [4, 5])
  with self.assertRaisesRegex(TypeError, "flat_sequence"):
    tree.unflatten_as([4, 5], "bad_sequence")
  with self.assertRaises(ValueError):
    tree.unflatten_as([5, 6, [7, 8]], ["a", "b", "c"])
def testFlattenDictOrder(self):
  """Dict values flatten in sorted-key order, ignoring insertion order."""
  pairs = [("d", 3), ("b", 1), ("a", 0), ("c", 2)]
  ordered_flat = tree.flatten(collections.OrderedDict(pairs))
  plain_flat = tree.flatten(dict(pairs))
  for flat in (ordered_flat, plain_flat):
    self.assertEqual([0, 1, 2, 3], flat)
def testUnflattenDictOrder(self):
  """Values are assigned by sorted key; the template's key order is kept."""
  keys = ["d", "b", "a", "c"]
  ordered = collections.OrderedDict((key, 0) for key in keys)
  plain = {key: 0 for key in keys}
  seq = [0, 1, 2, 3]
  ordered_reconstruction = tree.unflatten_as(ordered, seq)
  plain_reconstruction = tree.unflatten_as(plain, seq)
  self.assertEqual(
      collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
      ordered_reconstruction)
  self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
def testFlattenAndUnflatten_withDicts(self):
  # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
  named_tuple = collections.namedtuple("A", ("b", "c"))
  mess = [
      "z",
      named_tuple(3, 4),
      {
          "c": [1, collections.OrderedDict([("b", 3), ("a", 2)])],
          "b": 5,
      },
      17,
  ]
  flattened = tree.flatten(mess)
  # Dict values come out by sorted key: "b" (5) before "c" ([1, ...]), and
  # inside the OrderedDict "a" (2) before "b" (3).
  self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])

  # Same shape with different leaves; only its structure is used.
  structure_of_mess = [
      14,
      named_tuple("a", True),
      {
          "c": [0, collections.OrderedDict([("b", 9), ("a", 8)])],
          "b": 3,
      },
      "hi everybody",
  ]
  unflattened = tree.unflatten_as(structure_of_mess, flattened)
  self.assertEqual(mess, unflattened)

  # Check also that the OrderedDict was created, with the correct key order.
  unflattened_ordered_dict = unflattened[2]["c"][1]
  self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
  self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
def testFlatten_numpyIsNotFlattened(self):
  """A numpy array is a leaf, so flattening it yields one element."""
  self.assertLen(tree.flatten(np.array([1, 2, 3])), 1)
def testFlatten_stringIsNotFlattened(self):
  """A string is a leaf and round-trips through flatten/unflatten_as."""
  text = "lots of letters"
  leaves = tree.flatten(text)
  self.assertLen(leaves, 1)
  self.assertEqual(text, tree.unflatten_as("goodbye", leaves))
def testFlatten_bytearrayIsNotFlattened(self):
  """A bytearray is a leaf and round-trips through flatten/unflatten_as."""
  payload = bytearray("bytes in an array", "ascii")
  leaves = tree.flatten(payload)
  self.assertLen(leaves, 1)
  self.assertEqual(leaves, [payload])
  template = bytearray("hello", "ascii")
  self.assertEqual(payload, tree.unflatten_as(template, leaves))
def testUnflattenSequenceAs_notIterableError(self):
  """A string is rejected as the flat sequence."""
  expected_message = "flat_sequence must be a sequence"
  with self.assertRaisesRegex(TypeError, expected_message):
    tree.unflatten_as("hi", "bye")
def testUnflattenSequenceAs_wrongLengthsError(self):
  """A length mismatch reports both element counts."""
  structure = ["hello", "world"]
  too_many = ["and", "goodbye", "again"]
  expected_message = (
      "Structure had 2 elements, but flat_sequence had 3 elements.")
  with self.assertRaisesRegex(ValueError, expected_message):
    tree.unflatten_as(structure, too_many)
def testUnflattenSequenceAs_defaultdict(self):
  """Unflattening into a defaultdict template fills in its values."""
  def make_defaultdict(pairs):
    return collections.defaultdict(list, pairs)

  structure = make_defaultdict([("a", [None]), ("b", [None, None])])
  expected = make_defaultdict([("a", [1]), ("b", [2, 3])])
  self.assertEqual(expected, tree.unflatten_as(structure, [1, 2, 3]))
def testIsSequence(self):
  """Checks which values `tree.is_nested` treats as nested."""
  # Text and binary strings are leaves.
  for leaf in ("1234", b"1234", u"1234", bytearray("1234", "ascii")):
    self.assertFalse(tree.is_nested(leaf))
  # Lists, tuples (including empty ones) and dicts are nested.
  for nested in ([1, 3, [4, 5]], ((7, 8), (5, 6)), [], {"a": 1, "b": 2}):
    self.assertTrue(tree.is_nested(nested))
  # Sets and numpy arrays are leaves.
  self.assertFalse(tree.is_nested(set([1, 2])))
  ones = np.ones([2, 3])
  for array in (ones, np.tanh(ones), np.ones((4, 5))):
    self.assertFalse(tree.is_nested(array))
# Namedtuple types used as fixtures by the assert_same_structure tests.
# pylint does not correctly recognize these as class names and
# suggests to use variable style under_score naming.
# pylint: disable=invalid-name
# Same fields, different type names.
Named0ab = collections.namedtuple("named_0", "a b")
Named1ab = collections.namedtuple("named_1", "a b")
# Distinct types that share the name "same_name"; SameNamexy has other fields.
SameNameab = collections.namedtuple("same_name", "a b")
SameNameab2 = collections.namedtuple("same_name", "a b")
SameNamexy = collections.namedtuple("same_name", "x y")
# Distinct types that share the name "same_name_1" and fields.
SameName1xy = collections.namedtuple("same_name_1", "x y")
SameName1xy2 = collections.namedtuple("same_name_1", "x y")
NotSameName = collections.namedtuple("not_same_name", "a b")
# pylint: enable=invalid-name

class SameNamedType1(SameNameab):
  # A subclass of a namedtuple type, i.e. not a "plain" namedtuple.
  pass
# pylint: disable=g-error-prone-assert-raises
def testAssertSameStructure(self):
  tree.assert_same_structure(STRUCTURE1, STRUCTURE2)
  # Strings, bytes and numpy arrays are leaves, so they match another leaf.
  for leaf in ("abc", b"abc", u"abc", bytearray("abc", "ascii")):
    tree.assert_same_structure(leaf, 1.0)
  tree.assert_same_structure("abc", np.array([0, 1]))
def testAssertSameStructure_differentNumElements(self):
  """The error names the mismatching substructures and both skeletons."""
  expected_error = (
      "The two structures don't have the same nested structure\\.\n\n"
      "First structure:.*?\n\n"
      "Second structure:.*\n\n"
      "More specifically: Substructure "
      r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
      'substructure "type=str str=spam" is not\n'
      "Entire first structure:\n"
      r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
      "Entire second structure:\n"
      r"\(\., \.\)")
  with self.assertRaisesRegex(ValueError, expected_error):
    tree.assert_same_structure(STRUCTURE1, STRUCTURE_DIFFERENT_NUM_ELEMENTS)
def testAssertSameStructure_listVsNdArray(self):
  """A list is nested while a numpy array is a leaf."""
  expected_error = (
      "The two structures don't have the same nested structure\\.\n\n"
      "First structure:.*?\n\n"
      "Second structure:.*\n\n"
      r'More specifically: Substructure "type=list str=\[0, 1\]" '
      r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
      "is not")
  with self.assertRaisesRegex(ValueError, expected_error):
    tree.assert_same_structure([0, 1], np.array([0, 1]))
def testAssertSameStructure_intVsList(self):
  """The error reports the sequence side first, whichever argument it is."""
  expected_error = (
      "The two structures don't have the same nested structure\\.\n\n"
      "First structure:.*?\n\n"
      "Second structure:.*\n\n"
      r'More specifically: Substructure "type=list str=\[0, 1\]" '
      'is a sequence, while substructure "type=int str=0" '
      "is not")
  with self.assertRaisesRegex(ValueError, expected_error):
    tree.assert_same_structure(0, [0, 1])
def testAssertSameStructure_tupleVsList(self):
  """With default type checking, a tuple does not match a list."""
  with self.assertRaises(TypeError):
    tree.assert_same_structure((0, 1), [0, 1])
def testAssertSameStructure_differentNesting(self):
  expected_error = ("don't have the same nested structure\\.\n\n"
                    "First structure: .*?\n\nSecond structure: ")
  with self.assertRaisesRegex(ValueError, expected_error):
    tree.assert_same_structure(STRUCTURE1, STRUCTURE_DIFFERENT_NESTING)
def testAssertSameStructure_tupleVsNamedTuple(self):
  """A plain tuple does not match a namedtuple of the same length."""
  named = NestTest.Named0ab("a", "b")
  with self.assertRaises(TypeError):
    tree.assert_same_structure((0, 1), named)
def testAssertSameStructure_sameNamedTupleDifferentContents(self):
  """Only structure is compared, so differing leaf values are fine."""
  first = NestTest.Named0ab(3, 4)
  second = NestTest.Named0ab("a", "b")
  tree.assert_same_structure(first, second)
def testAssertSameStructure_differentNamedTuples(self):
  """Namedtuples with different type names do not match."""
  first = NestTest.Named0ab(3, 4)
  second = NestTest.Named1ab(3, 4)
  with self.assertRaises(TypeError):
    tree.assert_same_structure(first, second)
def testAssertSameStructure_sameNamedTupleDifferentStructuredContents(self):
  """A leaf in one namedtuple field vs. a list in the other is a mismatch."""
  expected_error = ("don't have the same nested structure\\.\n\n"
                    "First structure: .*?\n\nSecond structure: ")
  with self.assertRaisesRegex(ValueError, expected_error):
    tree.assert_same_structure(NestTest.Named0ab(3, 4),
                               NestTest.Named0ab([3], 4))
def testAssertSameStructure_differentlyNestedLists(self):
  expected_error = ("don't have the same nested structure\\.\n\n"
                    "First structure: .*?\n\nSecond structure: ")
  first, second = [[3], 4], [3, [4]]
  with self.assertRaisesRegex(ValueError, expected_error):
    tree.assert_same_structure(first, second)
def testAssertSameStructure_listStructureWithAndWithoutTypes(self):
  """Tuple-vs-list differences matter only when check_types is enabled."""
  structure1_list = [[[1, 2], 3], 4, [5, 6]]
  with self.assertRaisesRegex(TypeError, "don't have the same sequence type"):
    tree.assert_same_structure(STRUCTURE1, structure1_list)
  for other in (STRUCTURE2, structure1_list):
    tree.assert_same_structure(STRUCTURE1, other, check_types=False)
def testAssertSameStructure_dictionaryDifferentKeys(self):
  """Dicts with different key sets do not match."""
  expected_error = "don't have the same set of keys"
  with self.assertRaisesRegex(ValueError, expected_error):
    tree.assert_same_structure({"a": 1}, {"b": 1})
def testAssertSameStructure_sameNameNamedTuples(self):
  """Distinct namedtuple types with equal name and fields still match."""
  first = NestTest.SameNameab(0, 1)
  second = NestTest.SameNameab2(2, 3)
  tree.assert_same_structure(first, second)
def testAssertSameStructure_sameNameNamedTuplesNested(self):
  # This assertion is expected to pass: two namedtuples with the same
  # name and field names are considered to be identical, also when nested.
  first = NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2)
  second = NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)
  tree.assert_same_structure(first, second)
def testAssertSameStructure_sameNameNamedTuplesDifferentStructure(self):
  """Matching namedtuple types do not excuse differently placed nesting."""
  first = NestTest.SameNameab(0, NestTest.SameNameab2(1, 2))
  second = NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)
  with self.assertRaisesRegex(
      ValueError, "The two structures don't have the same.*"):
    tree.assert_same_structure(first, second)
def testAssertSameStructure_differentNameNamedStructures(self):
  """Equal fields but different namedtuple names do not match."""
  first = NestTest.SameNameab(0, 1)
  second = NestTest.NotSameName(2, 3)
  with self.assertRaises(TypeError):
    tree.assert_same_structure(first, second)
def testAssertSameStructure_sameNameDifferentFieldNames(self):
  """Equal namedtuple names but different field names do not match."""
  first = NestTest.SameNameab(0, 1)
  second = NestTest.SameNamexy(2, 3)
  with self.assertRaises(TypeError):
    tree.assert_same_structure(first, second)
def testAssertSameStructure_classWrappingNamedTuple(self):
  """A subclass of a namedtuple does not match its base namedtuple."""
  base = NestTest.SameNameab(0, 1)
  subclass = NestTest.SameNamedType1(2, 3)
  with self.assertRaises(TypeError):
    tree.assert_same_structure(base, subclass)
# pylint: enable=g-error-prone-assert-raises
def testMapStructure(self):
# Exercises map_structure over one and two structures, scalars and empty
# containers, plus its argument and structure-mismatch validation.
structure2 = (((7, 8), 9), 10, (11, 12))
# Single structure: the result keeps STRUCTURE1's nesting.
structure1_plus1 = tree.map_structure(lambda x: x + 1, STRUCTURE1)
tree.assert_same_structure(STRUCTURE1, structure1_plus1)
self.assertAllEquals(
[2, 3, 4, 5, 6, 7],
tree.flatten(structure1_plus1))
# Two structures are combined leaf by leaf.
structure1_plus_structure2 = tree.map_structure(
lambda x, y: x + y, STRUCTURE1, structure2)
self.assertEqual(
(((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
structure1_plus_structure2)
# A scalar input is a single leaf, so the function is applied to it directly.
self.assertEqual(3, tree.map_structure(lambda x: x - 1, 4))
self.assertEqual(7, tree.map_structure(lambda x, y: x + y, 3, 4))
# Empty structures
self.assertEqual((), tree.map_structure(lambda x: x + 1, ()))
self.assertEqual([], tree.map_structure(lambda x: x + 1, []))
self.assertEqual({}, tree.map_structure(lambda x: x + 1, {}))
empty_nt = collections.namedtuple("empty_nt", "")
self.assertEqual(empty_nt(), tree.map_structure(lambda x: x + 1,
empty_nt()))
# This is checking actual equality of types, empty list != empty tuple
self.assertNotEqual((), tree.map_structure(lambda x: x + 1, []))
# Argument validation: non-callable function, no structures, and
# structures that differ in length, nesting or sequence type.
with self.assertRaisesRegex(TypeError, "callable"):
tree.map_structure("bad", structure1_plus1)
with self.assertRaisesRegex(ValueError, "at least one structure"):
tree.map_structure(lambda x: x)
with self.assertRaisesRegex(ValueError, "same number of elements"):
tree.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))
with self.assertRaisesRegex(ValueError, "same nested structure"):
tree.map_structure(lambda x, y: None, 3, (3,))
with self.assertRaisesRegex(TypeError, "same sequence type"):
tree.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
with self.assertRaisesRegex(ValueError, "same nested structure"):
tree.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
# Tuple vs. list is a type mismatch unless check_types=False; differing
# nesting is an error either way.
# NOTE(review): indentation was lost in this copy; the check_types=False call
# below presumably sits outside the preceding `with` block — confirm.
structure1_list = [[[1, 2], 3], 4, [5, 6]]
with self.assertRaisesRegex(TypeError, "same sequence type"):
tree.map_structure(lambda x, y: None, STRUCTURE1, structure1_list)
tree.map_structure(lambda x, y: None, STRUCTURE1, structure1_list,
check_types=False)
with self.assertRaisesRegex(ValueError, "same nested structure"):
tree.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
check_types=False)
# Unknown keyword arguments are rejected, with or without check_types.
with self.assertRaisesRegex(ValueError, "Only valid keyword argument.*foo"):
tree.map_structure(lambda x: None, STRUCTURE1, foo="a")
with self.assertRaisesRegex(ValueError, "Only valid keyword argument.*foo"):
tree.map_structure(lambda x: None, STRUCTURE1, check_types=False, foo="a")
def testMapStructureWithStrings(self):
  """Strings are leaves, so string operations apply to whole strings."""
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")

  words = ab_tuple(a="foo", b=("bar", "baz"))
  counts = ab_tuple(a=2, b=(1, 3))
  repeated = tree.map_structure(lambda text, times: text * times,
                                words, counts)
  self.assertEqual("foofoo", repeated.a)
  self.assertEqual("bar", repeated.b[0])
  self.assertEqual("bazbazbaz", repeated.b[1])

  nt = ab_tuple(a=("something", "something_else"), b="yet another thing")
  rev_nt = tree.map_structure(lambda x: x[::-1], nt)
  # Check the output is the correct structure, and all strings are reversed.
  tree.assert_same_structure(nt, rev_nt)
  for original, reversed_leaf in ((nt.a[0], rev_nt.a[0]),
                                  (nt.a[1], rev_nt.a[1]),
                                  (nt.b, rev_nt.b)):
    self.assertEqual(original[::-1], reversed_leaf)
def testAssertShallowStructure(self):
# Checks _assert_shallow_structure's length, sequence-type and dict-key
# validation, with and without check_types.
# A shallow tree with more elements than the input is a length mismatch.
inp_ab = ["a", "b"]
inp_abc = ["a", "b", "c"]
with self.assertRaisesRegex(
ValueError,
tree._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
input_length=len(inp_ab),
shallow_length=len(inp_abc))):
tree._assert_shallow_structure(inp_abc, inp_ab)
# Tuple (input) vs. list (shallow) at the same position is a type mismatch
# unless check_types=False.
# NOTE(review): indentation was lost in this copy; the check_types=False call
# presumably sits outside the preceding `with` block — confirm.
inp_ab1 = [(1, 1), (2, 2)]
inp_ab2 = [[1, 1], [2, 2]]
with self.assertRaisesWithLiteralMatch(
TypeError,
tree._STRUCTURES_HAVE_MISMATCHING_TYPES.format(
shallow_type=type(inp_ab2[0]),
input_type=type(inp_ab1[0]))):
tree._assert_shallow_structure(shallow_tree=inp_ab2, input_tree=inp_ab1)
tree._assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
# A key present in the shallow tree ("d") but absent from the input is
# reported.
inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
with self.assertRaisesWithLiteralMatch(
ValueError,
tree._SHALLOW_TREE_HAS_INVALID_KEYS.format(["d"])):
tree._assert_shallow_structure(inp_ab2, inp_ab1)
# OrderedDicts with the same keys in a different order are compatible.
inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
tree._assert_shallow_structure(inp_ab, inp_ba)
# regression test for b/130633904
tree._assert_shallow_structure({0: "foo"}, ["foo"], check_types=False)
def testFlattenUpTo(self):
# Checks flatten_up_to: `input_tree` is flattened only as deep as
# `shallow_tree` is nested, so the result holds the input's subtrees found
# at the shallow tree's leaf positions.
# Shallow tree ends at scalar.
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
shallow_tree = [[True, True], [False, True]]
flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
self.assertEqual(flattened_shallow_tree, [True, True, False, True])
# Shallow tree ends at string.
input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree,
input_tree)
input_tree_flattened = tree.flatten(input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[("a", 1), ("b", 2), ("c", 3), ("d", 4)])
self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
# Make sure dicts are correctly flattened, yielding values, not keys.
input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[1, {"c": 2}, 3, (4, 5)])
# Namedtuples.
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
input_tree = ab_tuple(a=[0, 1], b=2)
shallow_tree = ab_tuple(a=0, b=1)
input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[[0, 1], 2])
# Attrs.
@attr.s
class ABAttr(object):
a = attr.ib()
b = attr.ib()
input_tree = ABAttr(a=[0, 1], b=2)
shallow_tree = ABAttr(a=0, b=1)
input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[[0, 1], 2])
# Nested dicts, OrderedDicts and namedtuples.
input_tree = collections.OrderedDict(
[("a", ab_tuple(a=[0, {"b": 1}], b=2)),
("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
shallow_tree = input_tree
input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[ab_tuple(a=[0, {"b": 1}], b=2),
3,
collections.OrderedDict([("f", 4)])])
shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[ab_tuple(a=[0, {"b": 1}], b=2),
{"d": 3, "e": collections.OrderedDict([("f", 4)])}])
## Shallow non-list edge-case.
# A scalar shallow tree returns the whole input as a single element.
# Using iterable elements.
input_tree = ["input_tree"]
shallow_tree = "shallow_tree"
flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])
input_tree = ["input_tree_0", "input_tree_1"]
shallow_tree = "shallow_tree"
flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])
# Using non-iterable elements.
input_tree = [0]
shallow_tree = 9
flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])
input_tree = [0, 1]
shallow_tree = 9
flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])
## Both non-list edge-case.
# Using iterable elements.
input_tree = "input_tree"
shallow_tree = "shallow_tree"
flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])
# Using non-iterable elements.
input_tree = 0
shallow_tree = 0
flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])
## Input non-list edge-case.
# A sequence shallow tree requires a sequence input: a non-sequence input
# raises TypeError, while flattening the shallow tree against itself works.
# NOTE(review): indentation was lost in this copy; the two statements after
# each `with` block below presumably sit outside it (they could never run
# after the expected raise) — confirm.
# Using iterable elements.
input_tree = "input_tree"
shallow_tree = ["shallow_tree"]
with self.assertRaisesWithLiteralMatch(
TypeError,
tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, shallow_tree)
input_tree = "input_tree"
shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
with self.assertRaisesWithLiteralMatch(
TypeError,
tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, shallow_tree)
# Using non-iterable elements.
input_tree = 0
shallow_tree = [9]
with self.assertRaisesWithLiteralMatch(
TypeError,
tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, shallow_tree)
input_tree = 0
shallow_tree = [9, 8]
with self.assertRaisesWithLiteralMatch(
TypeError,
tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, shallow_tree)
def testByteStringsNotTreatedAsIterable(self):
  """Text and byte strings are leaves, so flatten_up_to keeps them whole."""
  strings = [u"unicode string", b"byte string"]
  self.assertEqual(strings, tree.flatten_up_to(strings, strings))
def testFlattenWithPathUpTo(self):
  """Tests `tree.flatten_with_path_up_to` on nested and edge-case inputs.

  Every case checks both the paths and the values produced when `input_tree`
  is flattened down to the depth of `shallow_tree`.
  """

  def get_paths_and_values(shallow_tree, input_tree):
    """Splits the flattened (path, value) pairs into two parallel lists."""
    path_value_pairs = tree.flatten_with_path_up_to(shallow_tree, input_tree)
    paths = [p for p, _ in path_value_pairs]
    values = [v for _, v in path_value_pairs]
    return paths, values

  def assert_non_nested_shallow_keeps_input_whole(shallow_tree, input_tree):
    """A non-nested shallow tree yields the whole input at the empty path."""
    (flattened_input_tree_paths,
     flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree_paths, [()])
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree_paths, [()])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

  def assert_nested_shallow_rejects_non_nested_input(
      shallow_tree, input_tree, expected_shallow_paths):
    """A nested shallow tree raises TypeError (with path) for a leaf input."""
    with self.assertRaisesWithLiteralMatch(
        TypeError,
        tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format(
            path=[], input_type=type(input_tree))):
      get_paths_and_values(shallow_tree, input_tree)
    # Run outside the `with` block: code after the raising call inside it
    # would never execute. Flattening the shallow tree up to itself succeeds.
    (flattened_shallow_tree_paths,
     flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree_paths, expected_shallow_paths)
    self.assertEqual(flattened_shallow_tree, shallow_tree)

  # Shallow tree ends at scalar.
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]
  (flattened_input_tree_paths,
   flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
  (flattened_shallow_tree_paths,
   flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
  self.assertEqual(flattened_input_tree_paths,
                   [(0, 0), (0, 1), (1, 0), (1, 1)])
  self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
  self.assertEqual(flattened_shallow_tree_paths,
                   [(0, 0), (0, 1), (1, 0), (1, 1)])
  self.assertEqual(flattened_shallow_tree, [True, True, False, True])

  # Shallow tree ends at string.
  input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
  shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
  (input_tree_flattened_as_shallow_tree_paths,
   input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
                                                                input_tree)
  input_tree_flattened_paths = [
      p for p, _ in tree.flatten_with_path(input_tree)
  ]
  input_tree_flattened = tree.flatten(input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                   [(0, 0), (0, 1, 0), (0, 1, 1, 0), (0, 1, 1, 1, 0)])
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
  self.assertEqual(input_tree_flattened_paths,
                   [(0, 0, 0), (0, 0, 1),
                    (0, 1, 0, 0), (0, 1, 0, 1),
                    (0, 1, 1, 0, 0), (0, 1, 1, 0, 1),
                    (0, 1, 1, 1, 0, 0), (0, 1, 1, 1, 0, 1)])
  self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

  # Make sure dicts are correctly flattened, yielding values, not keys.
  input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
  shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
  (input_tree_flattened_as_shallow_tree_paths,
   input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
                                                                input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                   [("a",), ("b",), ("d", 0), ("d", 1)])
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [1, {"c": 2}, 3, (4, 5)])

  # Namedtuples.
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  input_tree = ab_tuple(a=[0, 1], b=2)
  shallow_tree = ab_tuple(a=0, b=1)
  (input_tree_flattened_as_shallow_tree_paths,
   input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
                                                                input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                   [("a",), ("b",)])
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [[0, 1], 2])

  # Nested dicts, OrderedDicts and namedtuples.
  input_tree = collections.OrderedDict(
      [("a", ab_tuple(a=[0, {"b": 1}], b=2)),
       ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
  shallow_tree = input_tree
  (input_tree_flattened_as_shallow_tree_paths,
   input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
                                                                input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                   [("a", "a", 0),
                    ("a", "a", 1, "b"),
                    ("a", "b"),
                    ("c", "d"),
                    ("c", "e", "f")])
  self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
  shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
  (input_tree_flattened_as_shallow_tree_paths,
   input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
                                                                input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                   [("a",),
                    ("c", "d"),
                    ("c", "e")])
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [ab_tuple(a=[0, {"b": 1}], b=2),
                    3,
                    collections.OrderedDict([("f", 4)])])
  shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
  (input_tree_flattened_as_shallow_tree_paths,
   input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
                                                                input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                   [("a",), ("c",)])
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [ab_tuple(a=[0, {"b": 1}], b=2),
                    {"d": 3, "e": collections.OrderedDict([("f", 4)])}])

  # Test case where len(shallow_tree) < len(input_tree): input keys that are
  # absent from the shallow mapping are skipped, and only the shallow keys
  # are yielded (in sorted key order).
  input_tree = {"a": "A", "b": "B", "c": "C"}
  shallow_tree = {"a": 1, "c": 2}
  (input_tree_flattened_as_shallow_tree_paths,
   input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
                                                                input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
                   [("a",), ("c",)])
  self.assertEqual(input_tree_flattened_as_shallow_tree, ["A", "C"])

  ## Shallow non-list edge-case.
  # Using iterable elements.
  assert_non_nested_shallow_keeps_input_whole("shallow_tree", ["input_tree"])
  assert_non_nested_shallow_keeps_input_whole(
      "shallow_tree", ["input_tree_0", "input_tree_1"])
  # Using non-iterable elements.
  assert_non_nested_shallow_keeps_input_whole(9, [0])
  assert_non_nested_shallow_keeps_input_whole(9, [0, 1])

  ## Both non-list edge-case.
  # Using iterable elements.
  assert_non_nested_shallow_keeps_input_whole("shallow_tree", "input_tree")
  # Using non-iterable elements.
  assert_non_nested_shallow_keeps_input_whole(0, 0)

  ## Input non-list edge-case.
  # Using iterable elements.
  assert_nested_shallow_rejects_non_nested_input(
      ["shallow_tree"], "input_tree", [(0,)])
  assert_nested_shallow_rejects_non_nested_input(
      ["shallow_tree_9", "shallow_tree_8"], "input_tree", [(0,), (1,)])
  # Using non-iterable elements.
  assert_nested_shallow_rejects_non_nested_input([9], 0, [(0,)])
  assert_nested_shallow_rejects_non_nested_input([9, 8], 0, [(0,), (1,)])

  # Test that error messages include paths.
  input_tree = {"a": {"b": {0, 1}}}
  structure = {"a": {"b": [0, 1]}}
  with self.assertRaisesWithLiteralMatch(
      TypeError,
      tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format(
          path=["a", "b"], input_type=type(input_tree["a"]["b"]))):
    get_paths_and_values(structure, input_tree)
  (flattened_tree_paths,
   flattened_tree) = get_paths_and_values(structure, structure)
  self.assertEqual(flattened_tree_paths, [("a", "b", 0,), ("a", "b", 1,)])
  self.assertEqual(flattened_tree, structure["a"]["b"])
def testMapStructureUpTo(self):
  """Maps a function over the leaves of a shallow structure only.

  Deeper parts of the other inputs are handed to the function whole.
  """
  # Named tuples: every `ops` leaf is an op_tuple passed through intact.
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  op_tuple = collections.namedtuple("op_tuple", "add, mul")

  def apply_ops(val, ops):
    return (val + ops.add) * ops.mul

  inp_val = ab_tuple(a=2, b=3)
  inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  result = tree.map_structure_up_to(inp_val, apply_ops, inp_val, inp_ops,
                                    check_types=False)
  self.assertEqual(result.a, 6)
  self.assertEqual(result.b, 15)

  # Lists: each name is paired with the whole sub-list found at its position.
  data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  name_list = ["evens", ["odds", "primes"]]

  def describe(name, sec):
    return "first_{}_{}".format(len(sec), name)

  result = tree.map_structure_up_to(name_list, describe, name_list, data_list)
  self.assertEqual(result,
                   ["first_4_evens", ["first_5_odds", "first_3_primes"]])
# Namedtuple types for the parameterized cases of testFlattenWithPath; they are
# defined up front because they cannot be created inside the @parameterized
# argument lists themselves.
# pylint: disable=invalid-name
Foo, Bar = (collections.namedtuple("Foo", ["a", "b"]),
            collections.namedtuple("Bar", ["c", "d"]))
# pylint: enable=invalid-name
@parameterized.parameters([
    {"inputs": [], "expected": []},
    {"inputs": [23, "42"], "expected": [((0,), 23), ((1,), "42")]},
    {"inputs": [[[[108]]]], "expected": [((0, 0, 0, 0), 108)]},
    {"inputs": Foo(a=3, b=Bar(c=23, d=42)),
     "expected": [(("a",), 3), (("b", "c"), 23), (("b", "d"), 42)]},
    {"inputs": Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="thing")),
     "expected": [(("a", "c"), 23), (("a", "d"), 42), (("b", "c"), 0),
                  (("b", "d"), "thing")]},
    {"inputs": Bar(c=42, d=43),
     "expected": [(("c",), 42), (("d",), 43)]},
    {"inputs": Bar(c=[42], d=43),
     "expected": [(("c", 0), 42), (("d",), 43)]},
    {"inputs": wrapt.ObjectProxy(Bar(c=[42], d=43)),
     "expected": [(("c", 0), 42), (("d",), 43)]},
])
def testFlattenWithPath(self, inputs, expected):
  """`flatten_with_path` pairs every leaf with its path from the root."""
  flattened = tree.flatten_with_path(inputs)
  self.assertEqual(flattened, expected)
@parameterized.named_parameters([
    {"testcase_name": "Tuples", "s1": (1, 2), "s2": (3, 4),
     "check_types": True, "expected": (((0,), 4), ((1,), 6))},
    {"testcase_name": "Dicts", "s1": {"a": 1, "b": 2}, "s2": {"b": 4, "a": 3},
     "check_types": True,
     "expected": {"a": (("a",), 4), "b": (("b",), 6)}},
    {"testcase_name": "Mixed", "s1": (1, 2), "s2": [3, 4],
     "check_types": False, "expected": (((0,), 4), ((1,), 6))},
    {"testcase_name": "Nested",
     "s1": {"a": [2, 3], "b": [1, 2, 3]},
     "s2": {"b": [5, 6, 7], "a": [8, 9]},
     "check_types": True,
     "expected": {"a": [(("a", 0), 10), (("a", 1), 12)],
                  "b": [(("b", 0), 6), (("b", 1), 8), (("b", 2), 10)]}},
])
def testMapWithPathCompatibleStructures(self, s1, s2, check_types, expected):
  """Maps (path, sum of leaves) over two structurally compatible inputs."""
  result = tree.map_structure_with_path(
      lambda path, *values: (path, sum(values)), s1, s2,
      check_types=check_types)
  self.assertEqual(expected, result)
@parameterized.named_parameters([
    {"testcase_name": "Tuples", "s1": (1, 2, 3), "s2": (4, 5),
     "error_type": ValueError},
    {"testcase_name": "Dicts", "s1": {"a": 1}, "s2": {"b": 2},
     "error_type": ValueError},
    {"testcase_name": "Nested",
     "s1": {"a": [2, 3, 4], "b": [1, 3]},
     "s2": {"b": [5, 6], "a": [8, 9]},
     "error_type": ValueError},
])
def testMapWithPathIncompatibleStructures(self, s1, s2, error_type):
  """Inputs with mismatched lengths or keys make the mapping raise."""

  def always_zero(path, *leaves):
    del path, leaves  # Unused; only the structure check matters here.
    return 0

  with self.assertRaises(error_type):
    tree.map_structure_with_path(always_zero, s1, s2)
def testMappingProxyType(self):
  """A read-only `MappingProxyType` can be flattened, rebuilt and mapped."""
  proxy = types.MappingProxyType({"a": 1, "b": (2, 3)})
  shifted = types.MappingProxyType({"a": 4, "b": (5, 6)})

  def add_three(v):
    return v + 3

  self.assertEqual(tree.flatten(proxy), [1, 2, 3])
  self.assertEqual(tree.unflatten_as(proxy, [4, 5, 6]), shifted)
  self.assertEqual(tree.map_structure(add_three, proxy), shifted)
def testTraverseListsToTuples(self):
  """A bottom-up traversal can rebuild every list in the tree as a tuple."""

  def lists_to_tuples(node):
    if isinstance(node, list):
      return tuple(node)
    return node

  structure = [(1, 2), [3], {"a": [4]}]
  expected = ((1, 2), (3,), {"a": (4,)})
  self.assertEqual(
      expected, tree.traverse(lists_to_tuples, structure, top_down=False))
def testTraverseEarlyTermination(self):
  """A non-None return from the visitor replaces that subtree.

  The visitor does not descend into a replaced subtree, which the recorded
  visit order demonstrates.
  """
  structure = [(1, [2]), [3, (4, 5, 6)]]
  seen = []

  def visit(node):
    seen.append(node)
    if isinstance(node, tuple) and len(node) > 2:
      return "X"
    return None

  result = tree.traverse(visit, structure)
  self.assertEqual([(1, [2]), [3, "X"]], result)
  expected_visits = [
      [(1, [2]), [3, (4, 5, 6)]],
      (1, [2]), 1, [2], 2,
      [3, (4, 5, 6)], 3, (4, 5, 6),
  ]
  self.assertEqual(expected_visits, seen)
def testMapStructureAcrossSubtreesDict(self):
  """Input keys missing from the shallow dict are ignored when mapping."""
  shallow = {"a": 1, "b": {"c": 2}}
  deep1 = {"a": 2, "b": {"c": 3, "d": 2}, "e": 4}
  deep2 = {"a": 3, "b": {"c": 2, "d": 3}, "e": 1}

  def total(*leaves):
    return sum(leaves)

  def gather(*leaves):
    return leaves

  self.assertEqual(tree.map_structure_up_to(shallow, total, deep1, deep2),
                   {"a": 5, "b": {"c": 5}})
  self.assertEqual(tree.map_structure_up_to(shallow, gather, deep1, deep2),
                   {"a": (2, 3), "b": {"c": (3, 2)}})
def testMapStructureAcrossSubtreesNoneValues(self):
  """A `None` leaf in the shallow structure still marks a mapping position."""

  def total(*leaves):
    return sum(leaves)

  result = tree.map_structure_up_to([1, [None]], total,
                                    [1, [2, 3]], [2, [3, 4]])
  self.assertEqual(result, [3, [5]])
def testMapStructureAcrossSubtreesList(self):
  """Extra trailing list elements in the inputs are ignored."""

  def total(*leaves):
    return sum(leaves)

  result = tree.map_structure_up_to([1, [1]], total,
                                    [1, [2, 3]], [2, [3, 4]])
  self.assertEqual(result, [3, [5]])
def testMapStructureAcrossSubtreesTuple(self):
  """Extra trailing tuple elements in the inputs are ignored."""

  def total(*leaves):
    return sum(leaves)

  result = tree.map_structure_up_to((1, (1,)), total,
                                    (1, (2, 3)), (2, (3, 4)))
  self.assertEqual(result, (3, (5,)))
def testMapStructureAcrossSubtreesNamedTuple(self):
  """A shallow namedtuple selects a prefix of a wider namedtuple's fields."""
  wide_type = collections.namedtuple("Foo", ["x", "y"])
  narrow_type = collections.namedtuple("Bar", ["x"])

  def total(*leaves):
    return sum(leaves)

  result = tree.map_structure_up_to(narrow_type(1), total,
                                    wide_type(1, (1, 0)),
                                    wide_type(2, (2, 0)))
  self.assertEqual(result, narrow_type(3))
def testMapStructureAcrossSubtreesListTuple(self):
  """Lists and tuples are interchangeable between the shallow structure and
  the inputs; the output takes on the type of the shallow structure."""

  def total(*leaves):
    return sum(leaves)

  # Tuple inside the shallow structure, lists inside the inputs.
  result = tree.map_structure_up_to([1, (1,)], total,
                                    [1, [2, 3]], [2, [3, 4]])
  self.assertEqual(result, [3, (5,)])

  # List inside the shallow structure, tuples inside the inputs.
  result = tree.map_structure_up_to([1, [1]], total,
                                    [1, (2, 3)], [2, (3, 4)])
  self.assertEqual(result, [3, [5]])
def testNoneNodeIncluded(self):
  """`None` is treated as a leaf: `flatten` keeps it rather than dropping it."""
  # Spelled without redundant parentheses: `((1, None))` is this same flat
  # tuple, not a nested one.
  structure = (1, None)
  self.assertEqual(tree.flatten(structure), [1, None])
  # `None` nested inside sequences and mappings is kept as a leaf too.
  nested = [1, (None, 2), {"a": None}]
  self.assertEqual(tree.flatten(nested), [1, None, 2, None])
def testCustomClassMapWithPath(self):
  """A user-defined `Mapping` is traversed with paths and rebuilt by type."""

  class ExampleClass(Mapping[Any, Any]):
    """Minimal read-only mapping that wraps a private dict."""

    def __init__(self, *args, **kwargs):
      self._mapping = dict(*args, **kwargs)

    def __getitem__(self, key: Any) -> Any:
      return self._mapping[key]

    def __len__(self) -> int:
      return len(self._mapping)

    def __iter__(self) -> Iterator[Any]:
      return iter(self._mapping)

  def mapper(path, value):
    # Tag each leaf with its "/"-joined key path.
    full_path = "/".join(path)
    return f"{full_path}_{value}"

  nested_leaves = {"second": 2, "third": 3}
  output = tree.map_structure_with_path(
      mapper, ExampleClass({"first": 1, "nested": nested_leaves}))
  self.assertEqual(
      output,
      ExampleClass({
          "first": "first_1",
          "nested": {
              "second": "nested/second_2",
              "third": "nested/third_3",
          },
      }))
if __name__ == "__main__":
  # Discover and run every test case defined in this module.
  unittest.main(module="__main__")