pax_global_header00006660000000000000000000000064147467323720014531gustar00rootroot0000000000000052 comment=3e79613493826e1eab59cfcfa2e4a54b56900c77 tree-0.1.9/000077500000000000000000000000001474673237200124775ustar00rootroot00000000000000tree-0.1.9/.github/000077500000000000000000000000001474673237200140375ustar00rootroot00000000000000tree-0.1.9/.github/workflows/000077500000000000000000000000001474673237200160745ustar00rootroot00000000000000tree-0.1.9/.github/workflows/build.yml000066400000000000000000000053701474673237200177230ustar00rootroot00000000000000name: build on: push: branches: [master] pull_request: branches: [master] release: types: [created] workflow_dispatch: jobs: sdist: name: sdist runs-on: ubuntu-24.04 steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: '3.11' - name: Create sdist run: | python -m pip install --upgrade pip setuptools python setup.py sdist shell: bash - name: List output directory run: ls -lh dist/dm_tree*.tar.gz shell: bash - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 if: ${{ github.event_name == 'workflow_dispatch' || (github.event_name == 'release' && github.event.action == 'created') }} with: name: dm-tree-sdist path: dist/dm_tree*.tar.gz bdist-wheel: name: Build wheels on ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-24.04, macos-14, windows-2022] # latest steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: true - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: "3.11" - name: Set up QEMU if: runner.os == 'Linux' uses: docker/setup-qemu-action@53851d14592bedcffcf25ea515637cff71ef929a # v3.3.0 with: platforms: all # This should be temporary # xref https://github.com/docker/setup-qemu-action/issues/188 # xref https://github.com/tonistiigi/binfmt/issues/215 image: tonistiigi/binfmt:qemu-v8.1.5 - name: Install cibuildwheel run: python -m pip install cibuildwheel==2.22.0 - name: Build wheels run: python -m cibuildwheel --output-dir wheelhouse env: CIBW_ARCHS_LINUX: auto aarch64 CIBW_ARCHS_MACOS: universal2 CIBW_BUILD: "cp310-* cp311-* cp312-* cp313-* cp313t-*" CIBW_BUILD_VERBOSITY: 1 CIBW_FREE_THREADED_SUPPORT: True CIBW_PRERELEASE_PYTHONS: True CIBW_SKIP: "*musllinux* *i686* *win32* *t-win*" CIBW_TEST_COMMAND: pytest --pyargs tree CIBW_TEST_REQUIRES: pytest MAKEFLAGS: "-j$(nproc)" - uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 if: ${{ github.event_name == 'workflow_dispatch' || (github.event_name == 'release' && github.event.action == 'created') }} with: name: dm-tree-bdist-wheel-${{ matrix.os }}-${{ strategy.job-index }} path: wheelhouse/*.whl tree-0.1.9/CONTRIBUTING.md000066400000000000000000000021151474673237200147270ustar00rootroot00000000000000# How to Contribute We'd love to accept your patches and contributions to this project. There are just a few small guidelines you need to follow. ## Contributor License Agreement Contributions to this project must be accompanied by a Contributor License Agreement. You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to to see your current agreements on file or to sign a new one. You generally only need to submit a CLA once, so if you've already submitted one (even if it was for a different project), you probably don't need to do it again. ## Code reviews All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. ## Community Guidelines This project follows [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). tree-0.1.9/LICENSE000066400000000000000000000261361474673237200135140ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. tree-0.1.9/MANIFEST.in000066400000000000000000000003421474673237200142340ustar00rootroot00000000000000# metadata include LICENSE include WORKSPACE include CONTRIBUTING.md include README.md # python package requirements include requirements*.txt # tree files recursive-include . CMakeLists.txt *.cc *.cpp *.h *.sh *.py *.cmake tree-0.1.9/README.md000066400000000000000000000017361474673237200137650ustar00rootroot00000000000000# Tree `tree` is a library for working with nested data structures. In a way, `tree` generalizes the builtin `map` function which only supports flat sequences, and allows to apply a function to each "leaf" preserving the overall structure. ```python >>> import tree >>> structure = [[1], [[[2, 3]]], [4]] >>> tree.flatten(structure) [1, 2, 3, 4] >>> tree.map_structure(lambda v: v**2, structure) [[1], [[[4, 9]]], [16]] ``` `tree` is backed by an optimized C++ implementation suitable for use in demanding applications, such as machine learning models. ## Installation From PyPI: ```shell $ pip install dm-tree ``` Directly from github using pip: ```shell $ pip install git+git://github.com/deepmind/tree.git ``` Build from source: ```shell $ python setup.py install ``` ## Support If you are having issues, please let us know by filing an issue on our [issue tracker](https://github.com/deepmind/tree/issues). ## License The project is licensed under the Apache 2.0 license. tree-0.1.9/docs/000077500000000000000000000000001474673237200134275ustar00rootroot00000000000000tree-0.1.9/docs/Makefile000066400000000000000000000011051474673237200150640ustar00rootroot00000000000000# Minimal makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) tree-0.1.9/docs/api.rst000066400000000000000000000023611474673237200147340ustar00rootroot00000000000000############# API Reference ############# All ``tree`` functions operate on nested tree-like structures. A *structure* is recursively defined as:: Structure = Union[ Any, Sequence['Structure'], Mapping[Any, 'Structure'], 'AnyNamedTuple', ] .. TODO(slebedev): Support @dataclass classes if we make @attr.s .. support public. A single (non-nested) Python object is a perfectly valid structure:: >>> tree.map_structure(lambda v: v * 2, 42) 84 >>> tree.flatten(42) [42] You could check whether a structure is actually nested via :func:`~tree.is_nested`:: >>> tree.is_nested(42) False >>> tree.is_nested([42]) True Note that ``tree`` only supports acyclic structures. The behavior for structures with cycle references is undefined. .. currentmodule:: tree .. autofunction:: is_nested .. autofunction:: assert_same_structure .. autofunction:: unflatten_as .. autofunction:: flatten .. autofunction:: flatten_up_to .. autofunction:: flatten_with_path .. autofunction:: flatten_with_path_up_to .. autofunction:: map_structure .. autofunction:: map_structure_up_to .. autofunction:: map_structure_with_path .. autofunction:: map_structure_with_path_up_to .. autofunction:: traverse .. autodata:: MAP_TO_NONE tree-0.1.9/docs/changes.rst000066400000000000000000000024541474673237200155760ustar00rootroot00000000000000######### Changelog ######### Version 0.1.9 ============= Released 2025-01-30 * Dropped support for Python <3.10. Version 0.1.8 ============= Released 2022-12-19 * Bumped pybind11 to v2.10.1 to support Python 3.11. * Dropped support for Python 3.6. Version 0.1.7 ============= Released 2022-04-10 * The build is now done via CMake instead of Bazel. Version 0.1.6 ============= Released 2021-04-12 * Dropped support for Python 2.X. * Added a generalization of ``tree.traverse`` which keeps track of the current path during traversal. Version 0.1.5 ============= Released 2020-04-30 * Added a new function ``tree.traverse`` which allows to traverse a nested structure and apply a function to each subtree. Version 0.1.4 ============= Released 2020-03-27 * Added support for ``types.MappingProxyType`` on Python 3.X. Version 0.1.3 ============= Released 2020-01-30 * Fixed ``ImportError`` when ``wrapt`` was not available. Version 0.1.2 ============= Released 2020-01-29 * Added support for ``wrapt.ObjectWrapper`` objects. * Added ``StructureKV[K, V]`` and ``Structure = Structure[Text, V]`` types. Version 0.1.1 ============= Released 2019-11-07 * Ensured that the produced Linux wheels are manylinux2010-compatible. Version 0.1.0 ============= Released 2019-11-05 * Initial public release. tree-0.1.9/docs/conf.py000066400000000000000000000076421474673237200147370ustar00rootroot00000000000000# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """Configuration file for the Sphinx documentation builder.""" # This file only contains a selection of the most common options. For a full # list see the documentation: # http://www.sphinx-doc.org/en/master/config # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # pylint: disable=g-bad-import-order # pylint: disable=g-import-not-at-top import datetime import inspect import os import sys sys.path.insert(0, os.path.abspath('../')) import tree # -- Project information ----------------------------------------------------- project = 'Tree' copyright = f'{datetime.date.today().year}, DeepMind' # pylint: disable=redefined-builtin author = 'DeepMind' # -- General configuration --------------------------------------------------- master_doc = 'index' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.linkcode', 'sphinx.ext.napoleon', 'sphinx.ext.doctest' ] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # -- Options for autodoc ----------------------------------------------------- autodoc_default_options = { 'member-order': 'bysource', 'special-members': True, 'exclude-members': '__repr__, __str__, __weakref__', } # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'sphinx_rtd_theme' html_theme_options = { # 'collapse_navigation': False, # 'sticky_navigation': False, } # -- Options for doctest ----------------------------------------------------- doctest_global_setup = ''' import collections import numpy as np import tree ''' # -- Source code links ------------------------------------------------------- def linkcode_resolve(domain, info): """Resolve a GitHub URL corresponding to Python object.""" if domain != 'py': return None try: mod = sys.modules[info['module']] except ImportError: return None obj = mod try: for attr in info['fullname'].split('.'): obj = getattr(obj, attr) except AttributeError: return None else: obj = inspect.unwrap(obj) try: filename = inspect.getsourcefile(obj) except TypeError: return None try: source, lineno = inspect.getsourcelines(obj) except OSError: return None # TODO(slebedev): support tags after we release an initial version. return 'https://github.com/deepmind/tree/blob/master/tree/%s#L%d#L%d' % ( os.path.relpath(filename, start=os.path.dirname( tree.__file__)), lineno, lineno + len(source) - 1) tree-0.1.9/docs/index.rst000066400000000000000000000015711474673237200152740ustar00rootroot00000000000000################## Tree Documentation ################## .. toctree:: :maxdepth: 2 :hidden: api changes recipes ``tree`` is a library for working with nested data structures. In a way, ``tree`` generalizes the builtin :func:`map` function which only supports flat sequences, and allows to apply a function to each "leaf" preserving the overall structure. Here's a quick example:: >>> tree.map_structure(lambda v: v**2, [[1], [[[2, 3]]], [4]]) [[1], [[[4, 9]]], [16]] .. note:: ``tree`` has originally been part of TensorFlow and is available as ``tf.nest``. Installation ============ Install ``tree`` by running:: $ pip install dm-tree Support ======= If you are having issues, please let us know by filing an issue on our `issue tracker `_. License ======= Tree is licensed under the Apache 2.0 License. tree-0.1.9/docs/recipes.rst000066400000000000000000000020571474673237200156170ustar00rootroot00000000000000############ Recipes ############ Concatenate nested array structures =================================== >>> tree.map_structure(lambda *args: np.concatenate(args, axis=1), ... {'a': np.ones((2, 1))}, ... {'a': np.zeros((2, 1))}) {'a': array([[1., 0.], [1., 0.]])} >>> tree.map_structure(lambda *args: np.concatenate(args, axis=0), ... {'a': np.ones((2, 1))}, ... {'a': np.zeros((2, 1))}) {'a': array([[1.], [1.], [0.], [0.]])} Exclude "meta" keys while mapping across structures =================================================== >>> d1 = {'key_to_exclude': None, 'a': 1} >>> d2 = {'key_to_exclude': None, 'a': 2} >>> d3 = {'a': 3} >>> tree.map_structure_up_to({'a': True}, lambda x, y, z: x+y+z, d1, d2, d3) {'a': 6} Broadcast a value across a reference structure ============================================== >>> reference_tree = {'a': 1, 'b': (2, 3)} >>> value = np.inf >>> tree.map_structure(lambda _: value, reference_tree) {'a': inf, 'b': (inf, inf)} tree-0.1.9/docs/requirements.txt000066400000000000000000000000461474673237200167130ustar00rootroot00000000000000sphinx>=2.0.1 sphinx_rtd_theme>=0.4.3 tree-0.1.9/readthedocs.yml000066400000000000000000000004261474673237200155110ustar00rootroot00000000000000# Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details version: 2 sphinx: builder: html configuration: docs/conf.py fail_on_warning: false python: version: 3.7 install: - requirements: docs/requirements.txt tree-0.1.9/setup.py000066400000000000000000000132331474673237200142130ustar00rootroot00000000000000# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Setup for pip package.""" import os import platform import shutil import subprocess import sys import sysconfig import setuptools from setuptools.command import build_ext here = os.path.dirname(os.path.abspath(__file__)) def _get_tree_version(): """Parse the version string from tree/__init__.py.""" with open(os.path.join(here, 'tree', '__init__.py')) as f: try: version_line = next(line for line in f if line.startswith('__version__')) except StopIteration: raise ValueError('__version__ not defined in tree/__init__.py') else: ns = {} exec(version_line, ns) # pylint: disable=exec-used return ns['__version__'] class CMakeExtension(setuptools.Extension): """An extension with no sources. We do not want distutils to handle any of the compilation (instead we rely on CMake), so we always pass an empty list to the constructor. """ def __init__(self, name, source_dir=''): super().__init__(name, sources=[]) self.source_dir = os.path.abspath(source_dir) class BuildCMakeExtension(build_ext.build_ext): """Our custom build_ext command. Uses CMake to build extensions instead of a bare compiler (e.g. gcc, clang). """ def run(self): self._check_build_environment() for ext in self.extensions: self.build_extension(ext) def _check_build_environment(self): """Check for required build tools: CMake, C++ compiler, and python dev.""" try: subprocess.check_call(['cmake', '--version']) except OSError as e: ext_names = ', '.join(e.name for e in self.extensions) raise RuntimeError( f'CMake must be installed to build the following extensions: {ext_names}' ) from e print('Found CMake') def build_extension(self, ext): extension_dir = os.path.abspath( os.path.dirname(self.get_ext_fullpath(ext.name))) build_cfg = 'Debug' if self.debug else 'Release' cmake_args = [ f'-DPython3_ROOT_DIR={sys.prefix}', f'-DPython3_EXECUTABLE={sys.executable}', f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extension_dir}', f'-DCMAKE_BUILD_TYPE={build_cfg}' ] if platform.system() != 'Windows': cmake_args.extend([ f'-DPython3_LIBRARY={sysconfig.get_paths()["stdlib"]}', f'-DPython3_INCLUDE_DIR={sysconfig.get_paths()["include"]}', ]) if platform.system() == 'Darwin' and os.environ.get('ARCHFLAGS'): osx_archs = [] if '-arch x86_64' in os.environ['ARCHFLAGS']: osx_archs.append('x86_64') if '-arch arm64' in os.environ['ARCHFLAGS']: osx_archs.append('arm64') cmake_args.append(f'-DCMAKE_OSX_ARCHITECTURES={";".join(osx_archs)}') os.makedirs(self.build_temp, exist_ok=True) subprocess.check_call( ['cmake', '-S', ext.source_dir, '-B', self.build_temp] + cmake_args) num_jobs = () if self.parallel: num_jobs = (f'-j{self.parallel}',) subprocess.check_call([ 'cmake', '--build', self.build_temp, *num_jobs, '--config', build_cfg ]) # Force output to /. Amends CMake multigenerator output paths # on Windows and avoids Debug/ and Release/ subdirs, which is CMake default. tree_dir = os.path.join(extension_dir, 'tree') # pylint:disable=unreachable for cfg in ('Release', 'Debug'): cfg_dir = os.path.join(extension_dir, cfg) if os.path.isdir(cfg_dir): for f in os.listdir(cfg_dir): shutil.move(os.path.join(cfg_dir, f), tree_dir) setuptools.setup( name='dm-tree', version=_get_tree_version(), url='https://github.com/deepmind/tree', description='Tree is a library for working with nested data structures.', author='DeepMind', author_email='tree-copybara@google.com', long_description=open(os.path.join(here, 'README.md')).read(), long_description_content_type='text/markdown', # Contained modules and scripts. packages=setuptools.find_packages(), python_requires='>=3.10', install_requires=[ 'absl-py>=0.6.1', 'attrs>=18.2.0', 'numpy>=1.21', "numpy>=1.21.2; python_version>='3.10'", "numpy>=1.23.3; python_version>='3.11'", "numpy>=1.26.0; python_version>='3.12'", "numpy>=2.1.0; python_version>='3.13'", 'wrapt>=1.11.2', ], test_suite='tree', cmdclass=dict(build_ext=BuildCMakeExtension), ext_modules=[CMakeExtension('_tree', source_dir='tree')], zip_safe=False, # PyPI package information. classifiers=[ 'Development Status :: 4 - Beta', 'Intended Audience :: Developers', 'Intended Audience :: Science/Research', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', 'Programming Language :: Python :: 3.13', 'Topic :: Scientific/Engineering :: Mathematics', 'Topic :: Software Development :: Libraries', ], license='Apache 2.0', keywords='tree nest flatten', ) tree-0.1.9/tree/000077500000000000000000000000001474673237200134365ustar00rootroot00000000000000tree-0.1.9/tree/CMakeLists.txt000066400000000000000000000077621474673237200162120ustar00rootroot00000000000000# Version >= 3.24 required for new `FindPython` module and `FIND_PACKAGE_ARGS` # keyword of `FetchContent` module. # https://cmake.org/cmake/help/v3.24/release/3.24.html cmake_minimum_required(VERSION 3.24) cmake_policy(SET CMP0135 NEW) project (tree LANGUAGES CXX) option(USE_SYSTEM_ABSEIL "Force use of system abseil-cpp" OFF) option(USE_SYSTEM_PYBIND11 "Force use of system pybind11" OFF) # Required for Python.h and python binding. find_package(Python3 COMPONENTS Interpreter Development) include_directories(SYSTEM ${Python3_INCLUDE_DIRS}) if(Python3_VERSION VERSION_LESS "3.6.0") message(FATAL_ERROR "Python found ${Python3_VERSION} < 3.6.0") endif() # Use C++14 standard. set(CMAKE_CXX_STANDARD 14 CACHE STRING "C++ version selection") # Position-independent code is needed for Python extension modules. set(CMAKE_POSITION_INDEPENDENT_CODE ON) # Set default build type. if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE RELEASE CACHE STRING "Choose the type of build: Debug Release." FORCE) endif() message("Current build type is: ${CMAKE_BUILD_TYPE}") message("PROJECT_BINARY_DIR is: ${PROJECT_BINARY_DIR}") if (NOT (WIN32 OR MSVC)) if(${CMAKE_BUILD_TYPE} STREQUAL "Debug") # Basic build for debugging (default). # -Og enables optimizations that do not interfere with debugging. set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Og") endif() if(${CMAKE_BUILD_TYPE} STREQUAL "Release") # Optimized release build: turn off debug runtime checks # and turn on highest speed optimizations. set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNDEBUG -O3") endif() endif() if(APPLE) # On MacOS: # -undefined dynamic_lookup is necessary for pybind11 linking set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-everything -w -undefined dynamic_lookup") # On MacOS, we need this so that CMake will use the right Python if the user # has a virtual environment active set (CMAKE_FIND_FRAMEWORK LAST) endif() # Use `FetchContent` module to manage all external dependencies (i.e. # abseil-cpp and pybind11). include(FetchContent) # Needed to disable Abseil tests. set(BUILD_TESTING OFF) # Try to find abseil-cpp package system-wide first. if (USE_SYSTEM_ABSEIL) message(STATUS "Use system abseil-cpp: ${USE_SYSTEM_ABSEIL}") set(ABSEIL_FIND_PACKAGE_ARGS FIND_PACKAGE_ARGS) endif() # Include abseil-cpp. set(ABSEIL_REPO https://github.com/abseil/abseil-cpp) set(ABSEIL_CMAKE_ARGS "-DCMAKE_INSTALL_PREFIX=${CMAKE_SOURCE_DIR}/abseil-cpp" "-DCMAKE_CXX_STANDARD=${CMAKE_CXX_STANDARD}" "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}" "-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}" "-DCMAKE_POSITION_INDEPENDENT_CODE=${CMAKE_POSITION_INDEPENDENT_CODE}" "-DLIBRARY_OUTPUT_PATH=${CMAKE_SOURCE_DIR}/abseil-cpp/lib" "-DABSL_PROPAGATE_CXX_STD=ON") if(DEFINED CMAKE_OSX_ARCHITECTURES) set(ABSEIL_CMAKE_ARGS ${ABSEIL_CMAKE_ARGS} "-DCMAKE_OSX_ARCHITECTURES=${CMAKE_OSX_ARCHITECTURES}") endif() FetchContent_Declare( absl URL ${ABSEIL_REPO}/archive/refs/tags/20220623.2.tar.gz URL_HASH SHA256=773652c0fc276bcd5c461668dc112d0e3b6cde499600bfe3499c5fdda4ed4a5b CMAKE_ARGS ${ABSEIL_CMAKE_ARGS} EXCLUDE_FROM_ALL ${ABSEIL_FIND_PACKAGE_ARGS}) # Try to find pybind11 package system-wide first. if (USE_SYSTEM_PYBIND11) message(STATUS "Use system pybind11: ${USE_SYSTEM_PYBIND11}") set(PYBIND11_FIND_PACKAGE_ARGS FIND_PACKAGE_ARGS) endif() FetchContent_Declare( pybind11 URL https://github.com/pybind/pybind11/archive/refs/tags/v2.10.1.tar.gz URL_HASH SHA256=111014b516b625083bef701df7880f78c2243835abdb263065b6b59b960b6bad ${PYBIND11_FIND_PACKAGE_ARGS}) FetchContent_MakeAvailable(absl pybind11) # Define pybind11 tree module. pybind11_add_module(_tree tree.h tree.cc) target_link_libraries( _tree PRIVATE absl::int128 absl::raw_hash_set absl::raw_logging_internal absl::strings absl::throw_delegate) # Make the module private to tree package. set_target_properties(_tree PROPERTIES OUTPUT_NAME tree/_tree) tree-0.1.9/tree/__init__.py000066400000000000000000001015121474673237200155470ustar00rootroot00000000000000# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Functions for working with nested data structures.""" from collections import abc as collections_abc import logging import sys from typing import Mapping, Sequence, TypeVar, Union from .sequence import _is_attrs from .sequence import _is_namedtuple from .sequence import _sequence_like from .sequence import _sorted # pylint: disable=g-import-not-at-top try: import wrapt ObjectProxy = wrapt.ObjectProxy except ImportError: class ObjectProxy(object): """Stub-class for `wrapt.ObjectProxy``.""" try: from tree import _tree except ImportError: if "sphinx" not in sys.modules: raise _tree = None # pylint: enable=g-import-not-at-top __all__ = [ "is_nested", "assert_same_structure", "unflatten_as", "flatten", "flatten_up_to", "flatten_with_path", "flatten_with_path_up_to", "map_structure", "map_structure_up_to", "map_structure_with_path", "map_structure_with_path_up_to", "traverse", "MAP_TO_NONE", ] __version__ = "0.1.9" # Note: this is *not* the same as `six.string_types`, which in Python3 is just # `(str,)` (i.e. it does not include byte strings). _TEXT_OR_BYTES = (str, bytes) _SHALLOW_TREE_HAS_INVALID_KEYS = ( "The shallow_tree's keys are not a subset of the input_tree's keys. The " "shallow_tree has the following keys that are not in the input_tree: {}.") _STRUCTURES_HAVE_MISMATCHING_TYPES = ( "The two structures don't have the same sequence type. Input structure has " "type {input_type}, while shallow structure has type {shallow_type}.") _STRUCTURES_HAVE_MISMATCHING_LENGTHS = ( "The two structures don't have the same sequence length. Input " "structure has length {input_length}, while shallow structure has length " "{shallow_length}." ) _INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = ( "The input_tree has fewer elements than the shallow_tree. Input structure " "has length {input_size}, while shallow structure has length " "{shallow_size}.") _IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ = ( "If shallow structure is a sequence, input must also be a sequence. " "Input has type: {}.") _IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH = ( "If shallow structure is a sequence, input must also be a sequence. " "Input at path: {path} has type: {input_type}.") K = TypeVar("K") V = TypeVar("V") # A generic monomorphic structure type, e.g. ``StructureKV[str, int]`` # is an arbitrarily nested structure where keys must be of type ``str`` # and values are integers. StructureKV = Union[ Sequence["StructureKV[K, V]"], Mapping[K, "StructureKV[K, V]"], V, ] Structure = StructureKV[str, V] def _get_attrs_items(obj): """Returns a list of (name, value) pairs from an attrs instance. The list will be sorted by name. Args: obj: an object. Returns: A list of (attr_name, attr_value) pairs. """ return [(attr.name, getattr(obj, attr.name)) for attr in obj.__class__.__attrs_attrs__] def _yield_value(iterable): for _, v in _yield_sorted_items(iterable): yield v def _yield_sorted_items(iterable): """Yield (key, value) pairs for `iterable` in a deterministic order. For Sequences, the key will be an int, the array index of a value. For Mappings, the key will be the dictionary key. For objects (e.g. namedtuples), the key will be the attribute name. In all cases, the keys will be iterated in sorted order. Args: iterable: an iterable. Yields: The iterable's (key, value) pairs, in order of sorted keys. """ if isinstance(iterable, collections_abc.Mapping): # Iterate through dictionaries in a deterministic order by sorting the # keys. Notice this means that we ignore the original order of `OrderedDict` # instances. This is intentional, to avoid potential bugs caused by mixing # ordered and plain dicts (e.g., flattening a dict but using a # corresponding `OrderedDict` to pack it back). for key in _sorted(iterable): yield key, iterable[key] elif _is_attrs(iterable): for item in _get_attrs_items(iterable): yield item elif _is_namedtuple(iterable): for field in iterable._fields: yield (field, getattr(iterable, field)) else: for item in enumerate(iterable): yield item def _num_elements(structure): if _is_attrs(structure): return len(getattr(structure.__class__, "__attrs_attrs__")) else: return len(structure) def is_nested(structure): """Checks if a given structure is nested. >>> tree.is_nested(42) False >>> tree.is_nested({"foo": 42}) True Args: structure: A structure to check. Returns: `True` if a given structure is nested, i.e. is a sequence, a mapping, or a namedtuple, and `False` otherwise. """ return _tree.is_sequence(structure) def flatten(structure): r"""Flattens a possibly nested structure into a list. >>> tree.flatten([[1, 2, 3], [4, [5], [[6]]]]) [1, 2, 3, 4, 5, 6] If `structure` is not nested, the result is a single-element list. >>> tree.flatten(None) [None] >>> tree.flatten(1) [1] In the case of dict instances, the sequence consists of the values, sorted by key to ensure deterministic behavior. This is true also for :class:`~collections.OrderedDict` instances: their sequence order is ignored, the sorting order of keys is used instead. The same convention is followed in :func:`~tree.unflatten`. This correctly unflattens dicts and ``OrderedDict``\ s after they have been flattened, and also allows flattening an ``OrderedDict`` and then unflattening it back using a corresponding plain dict, or vice-versa. Dictionaries with non-sortable keys cannot be flattened. >>> tree.flatten({100: 'world!', 6: 'Hello'}) ['Hello', 'world!'] Args: structure: An arbitrarily nested structure. Returns: A list, the flattened version of the input `structure`. Raises: TypeError: If `structure` is or contains a mapping with non-sortable keys. """ return _tree.flatten(structure) class _DotString(object): def __str__(self): return "." def __repr__(self): return "." _DOT = _DotString() def assert_same_structure(a, b, check_types=True): """Asserts that two structures are nested in the same way. >>> tree.assert_same_structure([(0, 1)], [(2, 3)]) Note that namedtuples with identical name and fields are always considered to have the same shallow structure (even with `check_types=True`). >>> Foo = collections.namedtuple('Foo', ['a', 'b']) >>> AlsoFoo = collections.namedtuple('Foo', ['a', 'b']) >>> tree.assert_same_structure(Foo(0, 1), AlsoFoo(2, 3)) Named tuples with different names are considered to have different shallow structures: >>> Bar = collections.namedtuple('Bar', ['a', 'b']) >>> tree.assert_same_structure(Foo(0, 1), Bar(2, 3)) Traceback (most recent call last): ... TypeError: The two structures don't have the same nested structure. ... Args: a: an arbitrarily nested structure. b: an arbitrarily nested structure. check_types: if `True` (default) types of sequences are checked as well, including the keys of dictionaries. If set to `False`, for example a list and a tuple of objects will look the same if they have the same size. Note that namedtuples with identical name and fields are always considered to have the same shallow structure. Raises: ValueError: If the two structures do not have the same number of elements or if the two structures are not nested in the same way. TypeError: If the two structures differ in the type of sequence in any of their substructures. Only possible if `check_types` is `True`. """ try: _tree.assert_same_structure(a, b, check_types) except (ValueError, TypeError) as e: str1 = str(map_structure(lambda _: _DOT, a)) str2 = str(map_structure(lambda _: _DOT, b)) raise type(e)("%s\n" "Entire first structure:\n%s\n" "Entire second structure:\n%s" % (e, str1, str2)) def _packed_nest_with_indices(structure, flat, index): """Helper function for ``unflatten_as``. Args: structure: Substructure (list / tuple / dict) to mimic. flat: Flattened values to output substructure for. index: Index at which to start reading from flat. Returns: The tuple (new_index, child), where: * new_index - the updated index into `flat` having processed `structure`. * packed - the subset of `flat` corresponding to `structure`, having started at `index`, and packed into the same nested format. Raises: ValueError: if `structure` contains more elements than `flat` (assuming indexing starts from `index`). """ packed = [] for s in _yield_value(structure): if is_nested(s): new_index, child = _packed_nest_with_indices(s, flat, index) packed.append(_sequence_like(s, child)) index = new_index else: packed.append(flat[index]) index += 1 return index, packed def unflatten_as(structure, flat_sequence): r"""Unflattens a sequence into a given structure. >>> tree.unflatten_as([[1, 2], [[3], [4]]], [5, 6, 7, 8]) [[5, 6], [[7], [8]]] If `structure` is a scalar, `flat_sequence` must be a single-element list; in this case the return value is ``flat_sequence[0]``. >>> tree.unflatten_as(None, [1]) 1 If `structure` is or contains a dict instance, the keys will be sorted to pack the flat sequence in deterministic order. This is true also for :class:`~collections.OrderedDict` instances: their sequence order is ignored, the sorting order of keys is used instead. The same convention is followed in :func:`~tree.flatten`. This correctly unflattens dicts and ``OrderedDict``\ s after they have been flattened, and also allows flattening an ``OrderedDict`` and then unflattening it back using a corresponding plain dict, or vice-versa. Dictionaries with non-sortable keys cannot be unflattened. >>> tree.unflatten_as({1: None, 2: None}, ['Hello', 'world!']) {1: 'Hello', 2: 'world!'} Args: structure: Arbitrarily nested structure. flat_sequence: Sequence to unflatten. Returns: `flat_sequence` unflattened into `structure`. Raises: ValueError: If `flat_sequence` and `structure` have different element counts. TypeError: If `structure` is or contains a mapping with non-sortable keys. """ if not is_nested(flat_sequence): raise TypeError("flat_sequence must be a sequence not a {}:\n{}".format( type(flat_sequence), flat_sequence)) if not is_nested(structure): if len(flat_sequence) != 1: raise ValueError("Structure is a scalar but len(flat_sequence) == %d > 1" % len(flat_sequence)) return flat_sequence[0] flat_structure = flatten(structure) if len(flat_structure) != len(flat_sequence): raise ValueError( "Could not pack sequence. Structure had %d elements, but flat_sequence " "had %d elements. Structure: %s, flat_sequence: %s." % (len(flat_structure), len(flat_sequence), structure, flat_sequence)) _, packed = _packed_nest_with_indices(structure, flat_sequence, 0) return _sequence_like(structure, packed) def map_structure(func, *structures, **kwargs): # pylint: disable=redefined-builtin """Maps `func` through given structures. >>> structure = [[1], [2], [3]] >>> tree.map_structure(lambda v: v**2, structure) [[1], [4], [9]] >>> tree.map_structure(lambda x, y: x * y, structure, structure) [[1], [4], [9]] >>> Foo = collections.namedtuple('Foo', ['a', 'b']) >>> structure = Foo(a=1, b=2) >>> tree.map_structure(lambda v: v * 2, structure) Foo(a=2, b=4) Args: func: A callable that accepts as many arguments as there are structures. *structures: Arbitrarily nested structures of the same layout. **kwargs: The only valid keyword argument is `check_types`. If `True` (default) the types of components within the structures have to be match, e.g. ``tree.map_structure(func, [1], (1,))`` will raise a `TypeError`, otherwise this is not enforced. Note that namedtuples with identical name and fields are considered to be the same type. Returns: A new structure with the same layout as the given ones. If the `structures` have components of varying types, the resulting structure will use the same types as ``structures[0]``. Raises: TypeError: If `func` is not callable. ValueError: If the two structures do not have the same number of elements or if the two structures are not nested in the same way. TypeError: If `check_types` is `True` and any two `structures` differ in the types of their components. ValueError: If no structures were given or if a keyword argument other than `check_types` is provided. """ if not callable(func): raise TypeError("func must be callable, got: %s" % func) if not structures: raise ValueError("Must provide at least one structure") check_types = kwargs.pop("check_types", True) if kwargs: raise ValueError( "Only valid keyword arguments are `check_types` " "not: `%s`" % ("`, `".join(kwargs.keys()))) for other in structures[1:]: assert_same_structure(structures[0], other, check_types=check_types) return unflatten_as(structures[0], [func(*args) for args in zip(*map(flatten, structures))]) def map_structure_with_path(func, *structures, **kwargs): """Maps `func` through given structures. This is a variant of :func:`~tree.map_structure` which accumulates a *path* while mapping through the structures. A path is a tuple of indices and/or keys which uniquely identifies the positions of the arguments passed to `func`. >>> tree.map_structure_with_path( ... lambda path, v: (path, v**2), ... [{"foo": 42}]) [{'foo': ((0, 'foo'), 1764)}] Args: func: A callable that accepts a path and as many arguments as there are structures. *structures: Arbitrarily nested structures of the same layout. **kwargs: The only valid keyword argument is `check_types`. If `True` (default) the types of components within the structures have to be match, e.g. ``tree.map_structure_with_path(func, [1], (1,))`` will raise a `TypeError`, otherwise this is not enforced. Note that namedtuples with identical name and fields are considered to be the same type. Returns: A new structure with the same layout as the given ones. If the `structures` have components of varying types, the resulting structure will use the same types as ``structures[0]``. Raises: TypeError: If `func` is not callable or if the `structures` do not have the same layout. TypeError: If `check_types` is `True` and any two `structures` differ in the types of their components. ValueError: If no structures were given or if a keyword argument other than `check_types` is provided. """ return map_structure_with_path_up_to(structures[0], func, *structures, **kwargs) def _yield_flat_up_to(shallow_tree, input_tree, path=()): """Yields (path, value) pairs of input_tree flattened up to shallow_tree. Args: shallow_tree: Nested structure. Traverse no further than its leaf nodes. input_tree: Nested structure. Return the paths and values from this tree. Must have the same upper structure as shallow_tree. path: Tuple. Optional argument, only used when recursing. The path from the root of the original shallow_tree, down to the root of the shallow_tree arg of this recursive call. Yields: Pairs of (path, value), where path the tuple path of a leaf node in shallow_tree, and value is the value of the corresponding node in input_tree. """ if (isinstance(shallow_tree, _TEXT_OR_BYTES) or not (isinstance(shallow_tree, (collections_abc.Mapping, collections_abc.Sequence)) or _is_namedtuple(shallow_tree) or _is_attrs(shallow_tree))): yield (path, input_tree) else: input_tree = dict(_yield_sorted_items(input_tree)) for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree): subpath = path + (shallow_key,) input_subtree = input_tree[shallow_key] for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree, input_subtree, path=subpath): yield (leaf_path, leaf_value) def _multiyield_flat_up_to(shallow_tree, *input_trees): """Same as `_yield_flat_up_to`, but takes multiple input trees.""" zipped_iterators = zip(*[_yield_flat_up_to(shallow_tree, input_tree) for input_tree in input_trees]) try: for paths_and_values in zipped_iterators: paths, values = zip(*paths_and_values) yield paths[:1] + values except KeyError as e: paths = locals().get("paths", ((),)) raise ValueError(f"Could not find key '{e.args[0]}' in some `input_trees`. " "Please ensure the structure of all `input_trees` are " "compatible with `shallow_tree`. The last valid path " f"yielded was {paths[0]}.") from e def _assert_shallow_structure(shallow_tree, input_tree, path=None, check_types=True): """Asserts that `shallow_tree` is a shallow structure of `input_tree`. That is, this function recursively tests if each key in shallow_tree has its corresponding key in input_tree. Examples: The following code will raise an exception: >>> shallow_tree = {"a": "A", "b": "B"} >>> input_tree = {"a": 1, "c": 2} >>> _assert_shallow_structure(shallow_tree, input_tree) Traceback (most recent call last): ... ValueError: The shallow_tree's keys are not a subset of the input_tree's ... The following code will raise an exception: >>> shallow_tree = ["a", "b"] >>> input_tree = ["c", ["d", "e"], "f"] >>> _assert_shallow_structure(shallow_tree, input_tree) Traceback (most recent call last): ... ValueError: The two structures don't have the same sequence length. ... By setting check_types=False, we drop the requirement that corresponding nodes in shallow_tree and input_tree have to be the same type. Sequences are treated equivalently to Mappables that map integer keys (indices) to values. The following code will therefore not raise an exception: >>> _assert_shallow_structure({0: "foo"}, ["foo"], check_types=False) Args: shallow_tree: an arbitrarily nested structure. input_tree: an arbitrarily nested structure. path: if not `None`, a tuple containing the current path in the nested structure. This is only used for more informative errror messages. check_types: if `True` (default) the sequence types of `shallow_tree` and `input_tree` have to be the same. Raises: TypeError: If `shallow_tree` is a sequence but `input_tree` is not. TypeError: If the sequence types of `shallow_tree` are different from `input_tree`. Only raised if `check_types` is `True`. ValueError: If the sequence lengths of `shallow_tree` are different from `input_tree`. """ if is_nested(shallow_tree): if not is_nested(input_tree): if path is not None: raise TypeError( _IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format( path=list(path), input_type=type(input_tree))) else: raise TypeError( _IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format( type(input_tree))) if isinstance(shallow_tree, ObjectProxy): shallow_type = type(shallow_tree.__wrapped__) else: shallow_type = type(shallow_tree) if check_types and not isinstance(input_tree, shallow_type): # Duck-typing means that nest should be fine with two different # namedtuples with identical name and fields. shallow_is_namedtuple = _is_namedtuple(shallow_tree, False) input_is_namedtuple = _is_namedtuple(input_tree, False) if shallow_is_namedtuple and input_is_namedtuple: # pylint: disable=protected-access if not _tree.same_namedtuples(shallow_tree, input_tree): raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( input_type=type(input_tree), shallow_type=shallow_type)) # pylint: enable=protected-access elif not (isinstance(shallow_tree, collections_abc.Mapping) and isinstance(input_tree, collections_abc.Mapping)): raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( input_type=type(input_tree), shallow_type=shallow_type)) if _num_elements(input_tree) != _num_elements(shallow_tree): raise ValueError( _STRUCTURES_HAVE_MISMATCHING_LENGTHS.format( input_length=_num_elements(input_tree), shallow_length=_num_elements(shallow_tree))) elif _num_elements(input_tree) < _num_elements(shallow_tree): raise ValueError( _INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format( input_size=_num_elements(input_tree), shallow_size=_num_elements(shallow_tree))) shallow_iter = _yield_sorted_items(shallow_tree) input_iter = _yield_sorted_items(input_tree) def get_matching_input_branch(shallow_key): for input_key, input_branch in input_iter: if input_key == shallow_key: return input_branch raise ValueError(_SHALLOW_TREE_HAS_INVALID_KEYS.format([shallow_key])) for shallow_key, shallow_branch in shallow_iter: input_branch = get_matching_input_branch(shallow_key) _assert_shallow_structure( shallow_branch, input_branch, path + (shallow_key,) if path is not None else None, check_types=check_types) def flatten_up_to(shallow_structure, input_structure, check_types=True): """Flattens `input_structure` up to `shallow_structure`. All further nested components in `input_structure` are retained as-is. >>> structure = [[1, 1], [2, 2]] >>> tree.flatten_up_to([None, None], structure) [[1, 1], [2, 2]] >>> tree.flatten_up_to([None, [None, None]], structure) [[1, 1], 2, 2] If `shallow_structure` and `input_structure` are not nested, the result is a single-element list: >>> tree.flatten_up_to(42, 1) [1] >>> tree.flatten_up_to(42, [1, 2, 3]) [[1, 2, 3]] Args: shallow_structure: A structure with the same (but possibly more shallow) layout as `input_structure`. input_structure: An arbitrarily nested structure. check_types: If `True`, check that each node in shallow_tree has the same type as the corresponding node in `input_structure`. Returns: A list, the partially flattened version of `input_structure` wrt `shallow_structure`. Raises: TypeError: If the layout of `shallow_structure` does not match that of `input_structure`. TypeError: If `check_types` is `True` and `shallow_structure` and `input_structure` differ in the types of their components. """ _assert_shallow_structure( shallow_structure, input_structure, path=None, check_types=check_types) # Discard paths returned by _yield_flat_up_to. return [v for _, v in _yield_flat_up_to(shallow_structure, input_structure)] def flatten_with_path_up_to(shallow_structure, input_structure, check_types=True): """Flattens `input_structure` up to `shallow_structure`. This is a combination of :func:`~tree.flatten_up_to` and :func:`~tree.flatten_with_path` Args: shallow_structure: A structure with the same (but possibly more shallow) layout as `input_structure`. input_structure: An arbitrarily nested structure. check_types: If `True`, check that each node in shallow_tree has the same type as the corresponding node in `input_structure`. Returns: A list of ``(path, item)`` pairs corresponding to the partially flattened version of `input_structure` wrt `shallow_structure`. Raises: TypeError: If the layout of `shallow_structure` does not match that of `input_structure`. TypeError: If `input_structure` is or contains a mapping with non-sortable keys. TypeError: If `check_types` is `True` and `shallow_structure` and `input_structure` differ in the types of their components. """ _assert_shallow_structure( shallow_structure, input_structure, path=(), check_types=check_types) return list(_yield_flat_up_to(shallow_structure, input_structure)) def map_structure_up_to(shallow_structure, func, *structures, **kwargs): """Maps `func` through given structures up to `shallow_structure`. This is a variant of :func:`~tree.map_structure` which only maps the given structures up to `shallow_structure`. All further nested components are retained as-is. >>> structure = [[1, 1], [2, 2]] >>> tree.map_structure_up_to([None, None], len, structure) [2, 2] >>> tree.map_structure_up_to([None, [None, None]], str, structure) ['[1, 1]', ['2', '2']] Args: shallow_structure: A structure with layout common to all `structures`. func: A callable that accepts as many arguments as there are structures. *structures: Arbitrarily nested structures of the same layout. **kwargs: No valid keyword arguments. Raises: ValueError: If `func` is not callable or if `structures` have different layout or if the layout of `shallow_structure` does not match that of `structures` or if no structures were given. Returns: A new structure with the same layout as `shallow_structure`. """ return map_structure_with_path_up_to( shallow_structure, lambda _, *args: func(*args), # Discards path. *structures, **kwargs) def map_structure_with_path_up_to(shallow_structure, func, *structures, **kwargs): """Maps `func` through given structures up to `shallow_structure`. This is a combination of :func:`~tree.map_structure_up_to` and :func:`~tree.map_structure_with_path` Args: shallow_structure: A structure with layout common to all `structures`. func: A callable that accepts a path and as many arguments as there are structures. *structures: Arbitrarily nested structures of the same layout. **kwargs: No valid keyword arguments. Raises: ValueError: If `func` is not callable or if `structures` have different layout or if the layout of `shallow_structure` does not match that of `structures` or if no structures were given. Returns: Result of repeatedly applying `func`. Has the same structure layout as `shallow_tree`. """ if "check_types" in kwargs: logging.warning("The use of `check_types` is deprecated and does not have " "any effect.") del kwargs results = [] for path_and_values in _multiyield_flat_up_to(shallow_structure, *structures): results.append(func(*path_and_values)) return unflatten_as(shallow_structure, results) def flatten_with_path(structure): r"""Flattens a possibly nested structure into a list. This is a variant of :func:`~tree.flattens` which produces a list of pairs: ``(path, item)``. A path is a tuple of indices and/or keys which uniquely identifies the position of the corresponding ``item``. >>> tree.flatten_with_path([{"foo": 42}]) [((0, 'foo'), 42)] Args: structure: An arbitrarily nested structure. Returns: A list of ``(path, item)`` pairs corresponding to the flattened version of the input `structure`. Raises: TypeError: If ``structure`` is or contains a mapping with non-sortable keys. """ return list(_yield_flat_up_to(structure, structure)) #: Special value for use with :func:`traverse`. MAP_TO_NONE = object() def traverse(fn, structure, top_down=True): """Traverses the given nested structure, applying the given function. The traversal is depth-first. If ``top_down`` is True (default), parents are returned before their children (giving the option to avoid traversing into a sub-tree). >>> visited = [] >>> tree.traverse(visited.append, [(1, 2), [3], {"a": 4}], top_down=True) [(1, 2), [3], {'a': 4}] >>> visited [[(1, 2), [3], {'a': 4}], (1, 2), 1, 2, [3], 3, {'a': 4}, 4] >>> visited = [] >>> tree.traverse(visited.append, [(1, 2), [3], {"a": 4}], top_down=False) [(1, 2), [3], {'a': 4}] >>> visited [1, 2, (1, 2), 3, [3], 4, {'a': 4}, [(1, 2), [3], {'a': 4}]] Args: fn: The function to be applied to each sub-nest of the structure. When traversing top-down: If ``fn(subtree) is None`` the traversal continues into the sub-tree. If ``fn(subtree) is not None`` the traversal does not continue into the sub-tree. The sub-tree will be replaced by ``fn(subtree)`` in the returned structure (to replace the sub-tree with None, use the special value :data:`MAP_TO_NONE`). When traversing bottom-up: If ``fn(subtree) is None`` the traversed sub-tree is returned unaltered. If ``fn(subtree) is not None`` the sub-tree will be replaced by ``fn(subtree)`` in the returned structure (to replace the sub-tree with None, use the special value :data:`MAP_TO_NONE`). structure: The structure to traverse. top_down: If True, parent structures will be visited before their children. Returns: The structured output from the traversal. """ return traverse_with_path(lambda _, x: fn(x), structure, top_down=top_down) def traverse_with_path(fn, structure, top_down=True): """Traverses the given nested structure, applying the given function. The traversal is depth-first. If ``top_down`` is True (default), parents are returned before their children (giving the option to avoid traversing into a sub-tree). >>> visited = [] >>> tree.traverse_with_path( ... lambda path, subtree: visited.append((path, subtree)), ... [(1, 2), [3], {"a": 4}], ... top_down=True) [(1, 2), [3], {'a': 4}] >>> visited == [ ... ((), [(1, 2), [3], {'a': 4}]), ... ((0,), (1, 2)), ... ((0, 0), 1), ... ((0, 1), 2), ... ((1,), [3]), ... ((1, 0), 3), ... ((2,), {'a': 4}), ... ((2, 'a'), 4)] True >>> visited = [] >>> tree.traverse_with_path( ... lambda path, subtree: visited.append((path, subtree)), ... [(1, 2), [3], {"a": 4}], ... top_down=False) [(1, 2), [3], {'a': 4}] >>> visited == [ ... ((0, 0), 1), ... ((0, 1), 2), ... ((0,), (1, 2)), ... ((1, 0), 3), ... ((1,), [3]), ... ((2, 'a'), 4), ... ((2,), {'a': 4}), ... ((), [(1, 2), [3], {'a': 4}])] True Args: fn: The function to be applied to the path to each sub-nest of the structure and the sub-nest value. When traversing top-down: If ``fn(path, subtree) is None`` the traversal continues into the sub-tree. If ``fn(path, subtree) is not None`` the traversal does not continue into the sub-tree. The sub-tree will be replaced by ``fn(path, subtree)`` in the returned structure (to replace the sub-tree with None, use the special value :data:`MAP_TO_NONE`). When traversing bottom-up: If ``fn(path, subtree) is None`` the traversed sub-tree is returned unaltered. If ``fn(path, subtree) is not None`` the sub-tree will be replaced by ``fn(path, subtree)`` in the returned structure (to replace the sub-tree with None, use the special value :data:`MAP_TO_NONE`). structure: The structure to traverse. top_down: If True, parent structures will be visited before their children. Returns: The structured output from the traversal. """ def traverse_impl(path, structure): """Recursive traversal implementation.""" def subtree_fn(item): subtree_path, subtree = item return traverse_impl(path + (subtree_path,), subtree) def traverse_subtrees(): if is_nested(structure): return _sequence_like(structure, map(subtree_fn, _yield_sorted_items(structure))) else: return structure if top_down: ret = fn(path, structure) if ret is None: return traverse_subtrees() elif ret is MAP_TO_NONE: return None else: return ret else: traversed_structure = traverse_subtrees() ret = fn(path, traversed_structure) if ret is None: return traversed_structure elif ret is MAP_TO_NONE: return None else: return ret return traverse_impl((), structure) tree-0.1.9/tree/sequence.py000066400000000000000000000100321474673237200156140ustar00rootroot00000000000000# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Contains _sequence_like and helpers for sequence data structures.""" import collections from collections import abc as collections_abc import types from tree import _tree # pylint: disable=g-import-not-at-top try: import wrapt ObjectProxy = wrapt.ObjectProxy except ImportError: class ObjectProxy(object): """Stub-class for `wrapt.ObjectProxy``.""" def _sorted(dictionary): """Returns a sorted list of the dict keys, with error if keys not sortable.""" try: return sorted(dictionary) except TypeError: raise TypeError("tree only supports dicts with sortable keys.") def _is_attrs(instance): return _tree.is_attrs(instance) def _is_namedtuple(instance, strict=False): """Returns True iff `instance` is a `namedtuple`. Args: instance: An instance of a Python object. strict: If True, `instance` is considered to be a `namedtuple` only if it is a "plain" namedtuple. For instance, a class inheriting from a `namedtuple` will be considered to be a `namedtuple` iff `strict=False`. Returns: True if `instance` is a `namedtuple`. """ return _tree.is_namedtuple(instance, strict) def _sequence_like(instance, args): """Converts the sequence `args` to the same type as `instance`. Args: instance: an instance of `tuple`, `list`, `namedtuple`, `dict`, or `collections.OrderedDict`. args: elements to be converted to the `instance` type. Returns: `args` with the type of `instance`. """ if isinstance(instance, (dict, collections_abc.Mapping)): # Pack dictionaries in a deterministic order by sorting the keys. # Notice this means that we ignore the original order of `OrderedDict` # instances. This is intentional, to avoid potential bugs caused by mixing # ordered and plain dicts (e.g., flattening a dict but using a # corresponding `OrderedDict` to pack it back). result = dict(zip(_sorted(instance), args)) keys_and_values = ((key, result[key]) for key in instance) if isinstance(instance, collections.defaultdict): # `defaultdict` requires a default factory as the first argument. return type(instance)(instance.default_factory, keys_and_values) elif isinstance(instance, types.MappingProxyType): # MappingProxyType requires a dict to proxy to. return type(instance)(dict(keys_and_values)) else: return type(instance)(keys_and_values) elif isinstance(instance, collections_abc.MappingView): # We can't directly construct mapping views, so we create a list instead return list(args) elif _is_namedtuple(instance) or _is_attrs(instance): if isinstance(instance, ObjectProxy): instance_type = type(instance.__wrapped__) else: instance_type = type(instance) try: if _is_attrs(instance): return instance_type( **{ attr.name: arg for attr, arg in zip(instance_type.__attrs_attrs__, args) }) else: return instance_type(*args) except Exception as e: raise TypeError( f"Couldn't traverse {instance!r} with arguments {args}") from e elif isinstance(instance, ObjectProxy): # For object proxies, first create the underlying type and then re-wrap it # in the proxy type. return type(instance)(_sequence_like(instance.__wrapped__, args)) else: # Not a namedtuple return type(instance)(args) tree-0.1.9/tree/tree.cc000066400000000000000000000603641474673237200147150ustar00rootroot00000000000000/* Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tree.h" #include #include #include #include // logging #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include #ifdef LOG #define LOG_WARNING(w) LOG(WARNING) << w; #else #include #define LOG_WARNING(w) std::cerr << w << "\n"; #endif #ifndef DCHECK #define DCHECK(stmt) #endif namespace py = pybind11; namespace tree { namespace { // PyObjectPtr wraps an underlying Python object and decrements the // reference count in the destructor. // // This class does not acquire the GIL in the destructor, so the GIL must be // held when the destructor is called. using PyObjectPtr = std::unique_ptr; const int kMaxItemsInCache = 1024; bool WarnedThatSetIsNotSequence = false; bool IsString(PyObject* o) { return PyBytes_Check(o) || PyByteArray_Check(o) || PyUnicode_Check(o); } // Equivalent to Python's 'o.__class__.__name__' // Note that '__class__' attribute is set only in new-style classes. // A lot of tensorflow code uses __class__ without checks, so it seems like // we only support new-style classes. absl::string_view GetClassName(PyObject* o) { // __class__ is equivalent to type() for new style classes. // type() is equivalent to PyObject_Type() // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type) // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which // we don't need here. PyTypeObject* type = o->ob_type; // __name__ is the value of `tp_name` after the last '.' // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name) absl::string_view name(type->tp_name); size_t pos = name.rfind('.'); if (pos != absl::string_view::npos) { name.remove_prefix(pos + 1); } return name; } std::string PyObjectToString(PyObject* o) { if (o == nullptr) { return ""; } PyObject* str = PyObject_Str(o); if (str) { std::string s(PyUnicode_AsUTF8(str)); Py_DECREF(str); return absl::StrCat("type=", GetClassName(o), " str=", s); } else { return ""; } } class CachedTypeCheck { public: explicit CachedTypeCheck(std::function ternary_predicate) : ternary_predicate_(std::move(ternary_predicate)) {} ~CachedTypeCheck() { for (const auto& pair : type_to_sequence_map_) { Py_DECREF(pair.first); } } // Caches successful executions of the one-argument (PyObject*) callable // "ternary_predicate" based on the type of "o". -1 from the callable // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type // does not match the predicate, and 1 indicates that it does. Used to avoid // calling back into Python for expensive isinstance checks. int CachedLookup(PyObject* o) { // Try not to return to Python - see if the type has already been seen // before. auto* type = Py_TYPE(o); { auto it = type_to_sequence_map_.find(type); if (it != type_to_sequence_map_.end()) { return it->second; } } int check_result = ternary_predicate_(o); if (check_result == -1) { return -1; // Type check error, not cached. } // NOTE: This is never decref'd as long as the object lives, which is likely // forever, but we don't want the type to get deleted as long as it is in // the map. This should not be too much of a leak, as there should only be a // relatively small number of types in the map, and an even smaller number // that are eligible for decref. As a precaution, we limit the size of the // map to 1024. { if (type_to_sequence_map_.size() < kMaxItemsInCache) { Py_INCREF(type); type_to_sequence_map_.insert({type, check_result}); } } return check_result; } private: std::function ternary_predicate_; std::unordered_map type_to_sequence_map_; }; py::object GetCollectionsSequenceType() { static py::object type = py::module::import("collections.abc").attr("Sequence"); return type; } py::object GetCollectionsMappingType() { static py::object type = py::module::import("collections.abc").attr("Mapping"); return type; } py::object GetCollectionsMappingViewType() { static py::object type = py::module::import("collections.abc").attr("MappingView"); return type; } py::object GetWraptObjectProxyTypeUncached() { try { return py::module::import("wrapt").attr("ObjectProxy"); } catch (const py::error_already_set& e) { if (e.matches(PyExc_ImportError)) return py::none(); throw e; } } py::object GetWraptObjectProxyType() { // TODO(gregthornton): Restore caching when deadlock issue is fixed. return GetWraptObjectProxyTypeUncached(); } // Returns 1 if `o` is considered a mapping for the purposes of Flatten(). // Returns 0 otherwise. // Returns -1 if an error occurred. int IsMappingHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { return PyObject_IsInstance(to_check, GetCollectionsMappingType().ptr()); }); if (PyDict_Check(o)) return true; return check_cache->CachedLookup(o); } // Returns 1 if `o` is considered a mapping view for the purposes of Flatten(). // Returns 0 otherwise. // Returns -1 if an error occurred. int IsMappingViewHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { return PyObject_IsInstance(to_check, GetCollectionsMappingViewType().ptr()); }); return check_cache->CachedLookup(o); } // Returns 1 if `o` is considered an object proxy // Returns 0 otherwise. // Returns -1 if an error occurred. int IsObjectProxy(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { auto type = GetWraptObjectProxyType(); return !type.is_none() && PyObject_IsInstance(to_check, type.ptr()) == 1; }); return check_cache->CachedLookup(o); } // Returns 1 if `o` is an instance of attrs-decorated class. // Returns 0 otherwise. int IsAttrsHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__")); if (cls) { return PyObject_HasAttrString(cls.get(), "__attrs_attrs__"); } // PyObject_GetAttrString returns null on error PyErr_Clear(); return 0; }); return check_cache->CachedLookup(o); } // Returns 1 if `o` is considered a sequence for the purposes of Flatten(). // Returns 0 otherwise. // Returns -1 if an error occurred. int IsSequenceHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { int is_instance = PyObject_IsInstance(to_check, GetCollectionsSequenceType().ptr()); // Don't cache a failed is_instance check. if (is_instance == -1) return -1; return static_cast(is_instance != 0 && !IsString(to_check)); }); // We treat dicts and other mappings as special cases of sequences. if (IsMappingHelper(o)) return true; if (IsMappingViewHelper(o)) return true; if (IsAttrsHelper(o)) return true; if (PySet_Check(o) && !WarnedThatSetIsNotSequence) { LOG_WARNING( "Sets are not currently considered sequences, " "but this may change in the future, " "so consider avoiding using them."); WarnedThatSetIsNotSequence = true; } return check_cache->CachedLookup(o); } using ValueIteratorPtr = std::unique_ptr; // Iterate through dictionaries in a deterministic order by sorting the // keys. Notice this means that we ignore the original order of // `OrderedDict` instances. This is intentional, to avoid potential // bugs caused by mixing ordered and plain dicts (e.g., flattening // a dict but using a corresponding `OrderedDict` to pack it back). class DictValueIterator : public ValueIterator { public: explicit DictValueIterator(PyObject* dict) : dict_(dict), keys_(PyDict_Keys(dict)) { if (PyList_Sort(keys_.get()) == -1) { invalidate(); } else { iter_.reset(PyObject_GetIter(keys_.get())); } } PyObjectPtr next() override { PyObjectPtr result; PyObjectPtr key(PyIter_Next(iter_.get())); if (key) { // PyDict_GetItem returns a borrowed reference. PyObject* elem = PyDict_GetItem(dict_, key.get()); if (elem) { Py_INCREF(elem); result.reset(elem); } else { PyErr_SetString(PyExc_RuntimeError, "Dictionary was modified during iteration over it"); } } return result; } private: PyObject* dict_; PyObjectPtr keys_; PyObjectPtr iter_; }; // Iterate over mapping objects by sorting the keys first class MappingValueIterator : public ValueIterator { public: explicit MappingValueIterator(PyObject* mapping) : mapping_(mapping), keys_(PyMapping_Keys(mapping)) { if (!keys_ || PyList_Sort(keys_.get()) == -1) { invalidate(); } else { iter_.reset(PyObject_GetIter(keys_.get())); } } PyObjectPtr next() override { PyObjectPtr result; PyObjectPtr key(PyIter_Next(iter_.get())); if (key) { // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference. PyObject* elem = PyObject_GetItem(mapping_, key.get()); if (elem) { result.reset(elem); } else { PyErr_SetString(PyExc_RuntimeError, "Mapping was modified during iteration over it"); } } return result; } private: PyObject* mapping_; PyObjectPtr keys_; PyObjectPtr iter_; }; // Iterate over a sequence, by index. class SequenceValueIterator : public ValueIterator { public: explicit SequenceValueIterator(PyObject* iterable) : seq_(PySequence_Fast(iterable, "")), size_(seq_.get() ? PySequence_Fast_GET_SIZE(seq_.get()) : 0), index_(0) {} PyObjectPtr next() override { PyObjectPtr result; if (index_ < size_) { // PySequence_Fast_GET_ITEM returns a borrowed reference. PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_); ++index_; if (elem) { Py_INCREF(elem); result.reset(elem); } } return result; } private: PyObjectPtr seq_; const Py_ssize_t size_; Py_ssize_t index_; }; class AttrsValueIterator : public ValueIterator { public: explicit AttrsValueIterator(PyObject* nested) : nested_(nested) { Py_INCREF(nested); cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__")); if (cls_) { attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__")); if (attrs_) { iter_.reset(PyObject_GetIter(attrs_.get())); } } if (!iter_ || PyErr_Occurred()) invalidate(); } PyObjectPtr next() override { PyObjectPtr result; PyObjectPtr item(PyIter_Next(iter_.get())); if (item) { PyObjectPtr name(PyObject_GetAttrString(item.get(), "name")); result.reset(PyObject_GetAttr(nested_.get(), name.get())); } return result; } private: PyObjectPtr nested_; PyObjectPtr cls_; PyObjectPtr attrs_; PyObjectPtr iter_; }; bool FlattenHelper( PyObject* nested, PyObject* list, const std::function& is_sequence_helper, const std::function& value_iterator_getter) { // if nested is not a sequence, append itself and exit int is_seq = is_sequence_helper(nested); if (is_seq == -1) return false; if (!is_seq) { return PyList_Append(list, nested) != -1; } ValueIteratorPtr iter = value_iterator_getter(nested); if (!iter->valid()) return false; for (PyObjectPtr item = iter->next(); item; item = iter->next()) { if (Py_EnterRecursiveCall(" in flatten")) { return false; } const bool success = FlattenHelper(item.get(), list, is_sequence_helper, value_iterator_getter); Py_LeaveRecursiveCall(); if (!success) { return false; } } return true; } // Sets error using keys of 'dict1' and 'dict2'. // 'dict1' and 'dict2' are assumed to be Python dictionaries. void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, std::string* error_msg, bool* is_type_error) { PyObjectPtr k1(PyMapping_Keys(dict1)); if (PyErr_Occurred() || k1.get() == nullptr) { *error_msg = ("The two dictionaries don't have the same set of keys. Failed to " "fetch keys."); return; } PyObjectPtr k2(PyMapping_Keys(dict2)); if (PyErr_Occurred() || k2.get() == nullptr) { *error_msg = ("The two dictionaries don't have the same set of keys. Failed to " "fetch keys."); return; } *is_type_error = false; *error_msg = absl::StrCat( "The two dictionaries don't have the same set of keys. " "First structure has keys ", PyObjectToString(k1.get()), ", while second structure has keys ", PyObjectToString(k2.get())); } // Returns true iff there were no "internal" errors. In other words, // errors that has nothing to do with structure checking. // If an "internal" error occurred, the appropriate Python error will be // set and the caller can propage it directly to the user. // // Both `error_msg` and `is_type_error` must be non-null. `error_msg` must // be empty. // Leaves `error_msg` empty if structures matched. Else, fills `error_msg` // with appropriate error and sets `is_type_error` to true iff // the error to be raised should be TypeError. bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types, std::string* error_msg, bool* is_type_error) { DCHECK(error_msg); DCHECK(is_type_error); const bool is_seq1 = IsSequence(o1); const bool is_seq2 = IsSequence(o2); if (PyErr_Occurred()) return false; if (is_seq1 != is_seq2) { std::string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2); std::string non_seq_str = is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1); *is_type_error = false; *error_msg = absl::StrCat("Substructure \"", seq_str, "\" is a sequence, while substructure \"", non_seq_str, "\" is not"); return true; } // Got to scalars, so finished checking. Structures are the same. if (!is_seq1) return true; if (check_types) { // Unwrap wrapt.ObjectProxy if needed. PyObjectPtr o1_wrapped; if (IsObjectProxy(o1)) { o1_wrapped.reset(PyObject_GetAttrString(o1, "__wrapped__")); o1 = o1_wrapped.get(); } PyObjectPtr o2_wrapped; if (IsObjectProxy(o2)) { o2_wrapped.reset(PyObject_GetAttrString(o2, "__wrapped__")); o2 = o2_wrapped.get(); } const PyTypeObject* type1 = o1->ob_type; const PyTypeObject* type2 = o2->ob_type; // We treat two different namedtuples with identical name and fields // as having the same type. const PyObject* o1_tuple = IsNamedtuple(o1, true); if (o1_tuple == nullptr) return false; const PyObject* o2_tuple = IsNamedtuple(o2, true); if (o2_tuple == nullptr) { Py_DECREF(o1_tuple); return false; } bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True; Py_DECREF(o1_tuple); Py_DECREF(o2_tuple); if (both_tuples) { const PyObject* same_tuples = SameNamedtuples(o1, o2); if (same_tuples == nullptr) return false; bool not_same_tuples = same_tuples != Py_True; Py_DECREF(same_tuples); if (not_same_tuples) { *is_type_error = true; *error_msg = absl::StrCat( "The two namedtuples don't have the same sequence type. " "First structure ", PyObjectToString(o1), " has type ", type1->tp_name, ", while second structure ", PyObjectToString(o2), " has type ", type2->tp_name); return true; } } else if (type1 != type2 /* If both sequences are list types, don't complain. This allows one to be a list subclass (e.g. _ListWrapper used for automatic dependency tracking.) */ && !(PyList_Check(o1) && PyList_Check(o2)) /* Two mapping types will also compare equal, making _DictWrapper and dict compare equal. */ && !(IsMappingHelper(o1) && IsMappingHelper(o2))) { *is_type_error = true; *error_msg = absl::StrCat( "The two namedtuples don't have the same sequence type. " "First structure ", PyObjectToString(o1), " has type ", type1->tp_name, ", while second structure ", PyObjectToString(o2), " has type ", type2->tp_name); return true; } if (PyDict_Check(o1) && PyDict_Check(o2)) { if (PyDict_Size(o1) != PyDict_Size(o2)) { SetDifferentKeysError(o1, o2, error_msg, is_type_error); return true; } PyObject* key; Py_ssize_t pos = 0; while (PyDict_Next(o1, &pos, &key, nullptr)) { if (PyDict_GetItem(o2, key) == nullptr) { SetDifferentKeysError(o1, o2, error_msg, is_type_error); return true; } } } else if (IsMappingHelper(o1)) { // Fallback for custom mapping types. Instead of using PyDict methods // which stay in C, we call iter(o1). if (PyMapping_Size(o1) != PyMapping_Size(o2)) { SetDifferentKeysError(o1, o2, error_msg, is_type_error); return true; } PyObjectPtr iter(PyObject_GetIter(o1)); PyObject* key; while ((key = PyIter_Next(iter.get())) != nullptr) { if (!PyMapping_HasKey(o2, key)) { SetDifferentKeysError(o1, o2, error_msg, is_type_error); Py_DECREF(key); return true; } Py_DECREF(key); } } } ValueIteratorPtr iter1 = GetValueIterator(o1); ValueIteratorPtr iter2 = GetValueIterator(o2); if (!iter1->valid() || !iter2->valid()) return false; while (true) { PyObjectPtr v1 = iter1->next(); PyObjectPtr v2 = iter2->next(); if (v1 && v2) { if (Py_EnterRecursiveCall(" in assert_same_structure")) { return false; } bool no_internal_errors = AssertSameStructureHelper( v1.get(), v2.get(), check_types, error_msg, is_type_error); Py_LeaveRecursiveCall(); if (!no_internal_errors) return false; if (!error_msg->empty()) return true; } else if (!v1 && !v2) { // Done with all recursive calls. Structure matched. return true; } else { *is_type_error = false; *error_msg = absl::StrCat( "The two structures don't have the same number of elements. ", "First structure: ", PyObjectToString(o1), ". Second structure: ", PyObjectToString(o2)); return true; } } } } // namespace bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; } bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; } PyObject* Flatten(PyObject* nested) { PyObject* list = PyList_New(0); if (FlattenHelper(nested, list, IsSequenceHelper, GetValueIterator)) { return list; } else { Py_DECREF(list); return nullptr; } } PyObject* IsNamedtuple(PyObject* o, bool strict) { // Unwrap wrapt.ObjectProxy if needed. PyObjectPtr o_wrapped; if (IsObjectProxy(o)) { o_wrapped.reset(PyObject_GetAttrString(o, "__wrapped__")); o = o_wrapped.get(); } // Must be subclass of tuple if (!PyTuple_Check(o)) { Py_RETURN_FALSE; } // If strict, o.__class__.__base__ must be tuple if (strict) { PyObject* klass = PyObject_GetAttrString(o, "__class__"); if (klass == nullptr) return nullptr; PyObject* base = PyObject_GetAttrString(klass, "__base__"); Py_DECREF(klass); if (base == nullptr) return nullptr; const PyTypeObject* base_type = reinterpret_cast(base); // built-in object types are singletons bool tuple_base = base_type == &PyTuple_Type; Py_DECREF(base); if (!tuple_base) { Py_RETURN_FALSE; } } // o must have attribute '_fields' and every element in // '_fields' must be a string. int has_fields = PyObject_HasAttrString(o, "_fields"); if (!has_fields) { Py_RETURN_FALSE; } PyObjectPtr fields(PyObject_GetAttrString(o, "_fields")); int is_instance = PyObject_IsInstance(fields.get(), GetCollectionsSequenceType().ptr()); if (is_instance == 0) { Py_RETURN_FALSE; } else if (is_instance == -1) { return nullptr; } PyObjectPtr seq(PySequence_Fast(fields.get(), "")); const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get()); for (Py_ssize_t i = 0; i < s; ++i) { // PySequence_Fast_GET_ITEM returns borrowed ref PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i); if (!IsString(elem)) { Py_RETURN_FALSE; } } Py_RETURN_TRUE; } PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) { PyObject* f1 = PyObject_GetAttrString(o1, "_fields"); PyObject* f2 = PyObject_GetAttrString(o2, "_fields"); if (f1 == nullptr || f2 == nullptr) { Py_XDECREF(f1); Py_XDECREF(f2); PyErr_SetString( PyExc_RuntimeError, "Expected namedtuple-like objects (that have _fields attr)"); return nullptr; } if (PyObject_RichCompareBool(f1, f2, Py_NE)) { Py_RETURN_FALSE; } if (GetClassName(o1).compare(GetClassName(o2)) == 0) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; } } void AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types) { std::string error_msg; bool is_type_error = false; AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error); if (PyErr_Occurred()) { // Don't hide Python exceptions while checking (e.g. errors fetching keys // from custom mappings). return; } if (!error_msg.empty()) { PyErr_SetString( is_type_error ? PyExc_TypeError : PyExc_ValueError, absl::StrCat( "The two structures don't have the same nested structure.\n\n", "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ", PyObjectToString(o2), "\n\nMore specifically: ", error_msg) .c_str()); } } ValueIteratorPtr GetValueIterator(PyObject* nested) { if (PyDict_Check(nested)) { return absl::make_unique(nested); } else if (IsMappingHelper(nested)) { return absl::make_unique(nested); } else if (IsAttrsHelper(nested)) { return absl::make_unique(nested); } else { return absl::make_unique(nested); } } namespace { inline py::object pyo_or_throw(PyObject* ptr) { if (PyErr_Occurred() || ptr == nullptr) { throw py::error_already_set(); } return py::reinterpret_steal(ptr); } PYBIND11_MODULE(_tree, m) { // Resolve `wrapt.ObjectProxy` at import time to avoid doing // imports during function calls. tree::GetWraptObjectProxyType(); m.def("assert_same_structure", [](py::handle& o1, py::handle& o2, bool check_types) { tree::AssertSameStructure(o1.ptr(), o2.ptr(), check_types); if (PyErr_Occurred()) { throw py::error_already_set(); } }); m.def("is_sequence", [](py::handle& o) { bool result = tree::IsSequence(o.ptr()); if (PyErr_Occurred()) { throw py::error_already_set(); } return result; }); m.def("is_namedtuple", [](py::handle& o, bool strict) { return pyo_or_throw(tree::IsNamedtuple(o.ptr(), strict)); }); m.def("is_attrs", [](py::handle& o) { bool result = tree::IsAttrs(o.ptr()); if (PyErr_Occurred()) { throw py::error_already_set(); } return result; }); m.def("same_namedtuples", [](py::handle& o1, py::handle& o2) { return pyo_or_throw(tree::SameNamedtuples(o1.ptr(), o2.ptr())); }); m.def("flatten", [](py::handle& nested) { return pyo_or_throw(tree::Flatten(nested.ptr())); }); } } // namespace } // namespace tree tree-0.1.9/tree/tree.h000066400000000000000000000112731474673237200145520ustar00rootroot00000000000000/* Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef TREE_H_ #define TREE_H_ #include #include namespace tree { // Returns a true if its input is a collections.Sequence (except strings). // // Args: // seq: an input sequence. // // Returns: // True if the sequence is a not a string and is a collections.Sequence or a // dict. bool IsSequence(PyObject* o); // Returns Py_True iff `instance` should be considered a `namedtuple`. // // Args: // instance: An instance of a Python object. // strict: If True, `instance` is considered to be a `namedtuple` only if // it is a "plain" namedtuple. For instance, a class inheriting // from a `namedtuple` will be considered to be a `namedtuple` // iff `strict=False`. // // Returns: // True if `instance` is a `namedtuple`. PyObject* IsNamedtuple(PyObject* o, bool strict); // Returns a true if its input is an instance of an attr.s decorated class. // // Args: // o: the input to be checked. // // Returns: // True if the object is an instance of an attr.s decorated class. bool IsAttrs(PyObject* o); // Returns Py_True iff the two namedtuples have the same name and fields. // Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have // '_fields' attribute). PyObject* SameNamedtuples(PyObject* o1, PyObject* o2); // Asserts that two structures are nested in the same way. // // Note that namedtuples with identical name and fields are always considered // to have the same shallow structure (even with `check_types=True`). // For intance, this code will print `True`: // // ```python // def nt(a, b): // return collections.namedtuple('foo', 'a b')(a, b) // print(assert_same_structure(nt(0, 1), nt(2, 3))) // ``` // // Args: // nest1: an arbitrarily nested structure. // nest2: an arbitrarily nested structure. // check_types: if `true`, types of sequences are checked as // well, including the keys of dictionaries. If set to `false`, for example // a list and a tuple of objects will look the same if they have the same // size. Note that namedtuples with identical name and fields are always // considered to have the same shallow structure. // // Raises: // ValueError: If the two structures do not have the same number of elements or // if the two structures are not nested in the same way. // TypeError: If the two structures differ in the type of sequence in any of // their substructures. Only possible if `check_types` is `True`. void AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types); // // Returns a flat list from a given nested structure. // // If `nest` is not a sequence, tuple, or dict, then returns a single-element // list: `[nest]`. // // In the case of dict instances, the sequence consists of the values, sorted by // key to ensure deterministic behavior. This is true also for `OrderedDict` // instances: their sequence order is ignored, the sorting order of keys is // used instead. The same convention is followed in `pack_sequence_as`. This // correctly repacks dicts and `OrderedDict`s after they have been flattened, // and also allows flattening an `OrderedDict` and then repacking it back using // a corresponding plain dict, or vice-versa. // Dictionaries with non-sortable keys cannot be flattened. // // Args: // nest: an arbitrarily nested structure or a scalar object. Note, numpy // arrays are considered scalars. // // Returns: // A Python list, the flattened version of the input. // On error, returns nullptr // // Raises: // TypeError: The nest is or contains a dict with non-sortable keys. PyObject* Flatten(PyObject* nested); struct DecrementsPyRefcount { void operator()(PyObject* p) const { Py_DECREF(p); } }; // ValueIterator interface class ValueIterator { public: virtual ~ValueIterator() {} virtual std::unique_ptr next() = 0; bool valid() const { return is_valid_; } protected: void invalidate() { is_valid_ = false; } private: bool is_valid_ = true; }; std::unique_ptr GetValueIterator(PyObject* nested); } // namespace tree #endif // TREE_H_ tree-0.1.9/tree/tree_benchmark.py000066400000000000000000000037041474673237200167650ustar00rootroot00000000000000# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Benchmarks for utilities working with arbitrarily nested structures.""" import collections import timeit import tree TIME_UNITS = [ (1, "s"), (10**-3, "ms"), (10**-6, "us"), (10**-9, "ns"), ] def format_time(time): for d, unit in TIME_UNITS: if time > d: return "{:.2f}{}".format(time / d, unit) def run_benchmark(benchmark_fn, num_iters): times = timeit.repeat(benchmark_fn, repeat=2, number=num_iters) return times[-1] / num_iters # Discard the first half for "warmup". def map_to_list(func, *args): return list(map(func, *args)) def benchmark_map(map_fn, structure): def benchmark_fn(): return map_fn(lambda v: v, structure) return benchmark_fn BENCHMARKS = collections.OrderedDict([ ("tree_map_1", benchmark_map(tree.map_structure, [0])), ("tree_map_8", benchmark_map(tree.map_structure, [0] * 8)), ("tree_map_64", benchmark_map(tree.map_structure, [0] * 64)), ("builtin_map_1", benchmark_map(map_to_list, [0])), ("builtin_map_8", benchmark_map(map_to_list, [0] * 8)), ("builtin_map_64", benchmark_map(map_to_list, [0] * 64)), ]) def main(): for name, benchmark_fn in BENCHMARKS.items(): print(name, format_time(run_benchmark(benchmark_fn, num_iters=1000))) if __name__ == "__main__": main() tree-0.1.9/tree/tree_test.py000066400000000000000000001357221474673237200160200ustar00rootroot00000000000000# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for utilities working with arbitrarily nested structures.""" import collections import doctest import types from typing import Any, Iterator, Mapping import unittest from absl.testing import parameterized import attr import numpy as np import tree import wrapt STRUCTURE1 = (((1, 2), 3), 4, (5, 6)) STRUCTURE2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) STRUCTURE_DIFFERENT_NUM_ELEMENTS = ("spam", "eggs") STRUCTURE_DIFFERENT_NESTING = (((1, 2), 3), 4, 5, (6,)) class DoctestTest(parameterized.TestCase): def testDoctest(self): extraglobs = { "collections": collections, "tree": tree, } num_failed, num_attempted = doctest.testmod( tree, extraglobs=extraglobs, optionflags=doctest.ELLIPSIS) self.assertGreater(num_attempted, 0, "No doctests found.") self.assertEqual(num_failed, 0, "{} doctests failed".format(num_failed)) class NestTest(parameterized.TestCase): def assertAllEquals(self, a, b): self.assertTrue((np.asarray(a) == b).all()) def testAttrsFlattenAndUnflatten(self): class BadAttr(object): """Class that has a non-iterable __attrs_attrs__.""" __attrs_attrs__ = None @attr.s class SampleAttr(object): field1 = attr.ib() field2 = attr.ib() field_values = [1, 2] sample_attr = SampleAttr(*field_values) self.assertFalse(tree._is_attrs(field_values)) self.assertTrue(tree._is_attrs(sample_attr)) flat = tree.flatten(sample_attr) self.assertEqual(field_values, flat) restructured_from_flat = tree.unflatten_as(sample_attr, flat) self.assertIsInstance(restructured_from_flat, SampleAttr) self.assertEqual(restructured_from_flat, sample_attr) # Check that flatten fails if attributes are not iterable with self.assertRaisesRegex(TypeError, "object is not iterable"): flat = tree.flatten(BadAttr()) @parameterized.parameters([ (1, 2, 3), ({"B": 10, "A": 20}, [1, 2], 3), ((1, 2), [3, 4], 5), (collections.namedtuple("Point", ["x", "y"])(1, 2), 3, 4), wrapt.ObjectProxy( (collections.namedtuple("Point", ["x", "y"])(1, 2), 3, 4)) ]) def testAttrsMapStructure(self, *field_values): @attr.s class SampleAttr(object): field3 = attr.ib() field1 = attr.ib() field2 = attr.ib() structure = SampleAttr(*field_values) new_structure = tree.map_structure(lambda x: x, structure) self.assertEqual(structure, new_structure) def testFlattenAndUnflatten(self): structure = ((3, 4), 5, (6, 7, (9, 10), 8)) flat = ["a", "b", "c", "d", "e", "f", "g", "h"] self.assertEqual(tree.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]) self.assertEqual( tree.unflatten_as(structure, flat), (("a", "b"), "c", ("d", "e", ("f", "g"), "h"))) point = collections.namedtuple("Point", ["x", "y"]) structure = (point(x=4, y=2), ((point(x=1, y=0),),)) flat = [4, 2, 1, 0] self.assertEqual(tree.flatten(structure), flat) restructured_from_flat = tree.unflatten_as(structure, flat) self.assertEqual(restructured_from_flat, structure) self.assertEqual(restructured_from_flat[0].x, 4) self.assertEqual(restructured_from_flat[0].y, 2) self.assertEqual(restructured_from_flat[1][0][0].x, 1) self.assertEqual(restructured_from_flat[1][0][0].y, 0) self.assertEqual([5], tree.flatten(5)) self.assertEqual([np.array([5])], tree.flatten(np.array([5]))) self.assertEqual("a", tree.unflatten_as(5, ["a"])) self.assertEqual( np.array([5]), tree.unflatten_as("scalar", [np.array([5])])) with self.assertRaisesRegex(ValueError, "Structure is a scalar"): tree.unflatten_as("scalar", [4, 5]) with self.assertRaisesRegex(TypeError, "flat_sequence"): tree.unflatten_as([4, 5], "bad_sequence") with self.assertRaises(ValueError): tree.unflatten_as([5, 6, [7, 8]], ["a", "b", "c"]) def testFlattenDictOrder(self): ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) plain = {"d": 3, "b": 1, "a": 0, "c": 2} ordered_flat = tree.flatten(ordered) plain_flat = tree.flatten(plain) self.assertEqual([0, 1, 2, 3], ordered_flat) self.assertEqual([0, 1, 2, 3], plain_flat) def testUnflattenDictOrder(self): ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) plain = {"d": 0, "b": 0, "a": 0, "c": 0} seq = [0, 1, 2, 3] ordered_reconstruction = tree.unflatten_as(ordered, seq) plain_reconstruction = tree.unflatten_as(plain, seq) self.assertEqual( collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), ordered_reconstruction) self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) def testFlattenAndUnflatten_withDicts(self): # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. named_tuple = collections.namedtuple("A", ("b", "c")) mess = [ "z", named_tuple(3, 4), { "c": [ 1, collections.OrderedDict([ ("b", 3), ("a", 2), ]), ], "b": 5 }, 17 ] flattened = tree.flatten(mess) self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17]) structure_of_mess = [ 14, named_tuple("a", True), { "c": [ 0, collections.OrderedDict([ ("b", 9), ("a", 8), ]), ], "b": 3 }, "hi everybody", ] self.assertEqual(mess, tree.unflatten_as(structure_of_mess, flattened)) # Check also that the OrderedDict was created, with the correct key order. unflattened_ordered_dict = tree.unflatten_as( structure_of_mess, flattened)[2]["c"][1] self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) def testFlatten_numpyIsNotFlattened(self): structure = np.array([1, 2, 3]) flattened = tree.flatten(structure) self.assertLen(flattened, 1) def testFlatten_stringIsNotFlattened(self): structure = "lots of letters" flattened = tree.flatten(structure) self.assertLen(flattened, 1) self.assertEqual(structure, tree.unflatten_as("goodbye", flattened)) def testFlatten_bytearrayIsNotFlattened(self): structure = bytearray("bytes in an array", "ascii") flattened = tree.flatten(structure) self.assertLen(flattened, 1) self.assertEqual(flattened, [structure]) self.assertEqual(structure, tree.unflatten_as(bytearray("hello", "ascii"), flattened)) def testUnflattenSequenceAs_notIterableError(self): with self.assertRaisesRegex(TypeError, "flat_sequence must be a sequence"): tree.unflatten_as("hi", "bye") def testUnflattenSequenceAs_wrongLengthsError(self): with self.assertRaisesRegex( ValueError, "Structure had 2 elements, but flat_sequence had 3 elements."): tree.unflatten_as(["hello", "world"], ["and", "goodbye", "again"]) def testUnflattenSequenceAs_defaultdict(self): structure = collections.defaultdict( list, [("a", [None]), ("b", [None, None])]) sequence = [1, 2, 3] expected = collections.defaultdict( list, [("a", [1]), ("b", [2, 3])]) self.assertEqual(expected, tree.unflatten_as(structure, sequence)) def testIsSequence(self): self.assertFalse(tree.is_nested("1234")) self.assertFalse(tree.is_nested(b"1234")) self.assertFalse(tree.is_nested(u"1234")) self.assertFalse(tree.is_nested(bytearray("1234", "ascii"))) self.assertTrue(tree.is_nested([1, 3, [4, 5]])) self.assertTrue(tree.is_nested(((7, 8), (5, 6)))) self.assertTrue(tree.is_nested([])) self.assertTrue(tree.is_nested({"a": 1, "b": 2})) self.assertFalse(tree.is_nested(set([1, 2]))) ones = np.ones([2, 3]) self.assertFalse(tree.is_nested(ones)) self.assertFalse(tree.is_nested(np.tanh(ones))) self.assertFalse(tree.is_nested(np.ones((4, 5)))) # pylint does not correctly recognize these as class names and # suggests to use variable style under_score naming. # pylint: disable=invalid-name Named0ab = collections.namedtuple("named_0", ("a", "b")) Named1ab = collections.namedtuple("named_1", ("a", "b")) SameNameab = collections.namedtuple("same_name", ("a", "b")) SameNameab2 = collections.namedtuple("same_name", ("a", "b")) SameNamexy = collections.namedtuple("same_name", ("x", "y")) SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) NotSameName = collections.namedtuple("not_same_name", ("a", "b")) # pylint: enable=invalid-name class SameNamedType1(SameNameab): pass # pylint: disable=g-error-prone-assert-raises def testAssertSameStructure(self): tree.assert_same_structure(STRUCTURE1, STRUCTURE2) tree.assert_same_structure("abc", 1.0) tree.assert_same_structure(b"abc", 1.0) tree.assert_same_structure(u"abc", 1.0) tree.assert_same_structure(bytearray("abc", "ascii"), 1.0) tree.assert_same_structure("abc", np.array([0, 1])) def testAssertSameStructure_differentNumElements(self): with self.assertRaisesRegex( ValueError, ("The two structures don't have the same nested structure\\.\n\n" "First structure:.*?\n\n" "Second structure:.*\n\n" "More specifically: Substructure " r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' 'substructure "type=str str=spam" is not\n' "Entire first structure:\n" r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" "Entire second structure:\n" r"\(\., \.\)")): tree.assert_same_structure(STRUCTURE1, STRUCTURE_DIFFERENT_NUM_ELEMENTS) def testAssertSameStructure_listVsNdArray(self): with self.assertRaisesRegex( ValueError, ("The two structures don't have the same nested structure\\.\n\n" "First structure:.*?\n\n" "Second structure:.*\n\n" r'More specifically: Substructure "type=list str=\[0, 1\]" ' r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' "is not")): tree.assert_same_structure([0, 1], np.array([0, 1])) def testAssertSameStructure_intVsList(self): with self.assertRaisesRegex( ValueError, ("The two structures don't have the same nested structure\\.\n\n" "First structure:.*?\n\n" "Second structure:.*\n\n" r'More specifically: Substructure "type=list str=\[0, 1\]" ' 'is a sequence, while substructure "type=int str=0" ' "is not")): tree.assert_same_structure(0, [0, 1]) def testAssertSameStructure_tupleVsList(self): self.assertRaises( TypeError, tree.assert_same_structure, (0, 1), [0, 1]) def testAssertSameStructure_differentNesting(self): with self.assertRaisesRegex( ValueError, ("don't have the same nested structure\\.\n\n" "First structure: .*?\n\nSecond structure: ")): tree.assert_same_structure(STRUCTURE1, STRUCTURE_DIFFERENT_NESTING) def testAssertSameStructure_tupleVsNamedTuple(self): self.assertRaises(TypeError, tree.assert_same_structure, (0, 1), NestTest.Named0ab("a", "b")) def testAssertSameStructure_sameNamedTupleDifferentContents(self): tree.assert_same_structure(NestTest.Named0ab(3, 4), NestTest.Named0ab("a", "b")) def testAssertSameStructure_differentNamedTuples(self): self.assertRaises(TypeError, tree.assert_same_structure, NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) def testAssertSameStructure_sameNamedTupleDifferentStructuredContents(self): with self.assertRaisesRegex( ValueError, ("don't have the same nested structure\\.\n\n" "First structure: .*?\n\nSecond structure: ")): tree.assert_same_structure(NestTest.Named0ab(3, 4), NestTest.Named0ab([3], 4)) def testAssertSameStructure_differentlyNestedLists(self): with self.assertRaisesRegex( ValueError, ("don't have the same nested structure\\.\n\n" "First structure: .*?\n\nSecond structure: ")): tree.assert_same_structure([[3], 4], [3, [4]]) def testAssertSameStructure_listStructureWithAndWithoutTypes(self): structure1_list = [[[1, 2], 3], 4, [5, 6]] with self.assertRaisesRegex(TypeError, "don't have the same sequence type"): tree.assert_same_structure(STRUCTURE1, structure1_list) tree.assert_same_structure(STRUCTURE1, STRUCTURE2, check_types=False) tree.assert_same_structure(STRUCTURE1, structure1_list, check_types=False) def testAssertSameStructure_dictionaryDifferentKeys(self): with self.assertRaisesRegex(ValueError, "don't have the same set of keys"): tree.assert_same_structure({"a": 1}, {"b": 1}) def testAssertSameStructure_sameNameNamedTuples(self): tree.assert_same_structure(NestTest.SameNameab(0, 1), NestTest.SameNameab2(2, 3)) def testAssertSameStructure_sameNameNamedTuplesNested(self): # This assertion is expected to pass: two namedtuples with the same # name and field names are considered to be identical. tree.assert_same_structure( NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) def testAssertSameStructure_sameNameNamedTuplesDifferentStructure(self): expected_message = "The two structures don't have the same.*" with self.assertRaisesRegex(ValueError, expected_message): tree.assert_same_structure( NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) def testAssertSameStructure_differentNameNamedStructures(self): self.assertRaises(TypeError, tree.assert_same_structure, NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) def testAssertSameStructure_sameNameDifferentFieldNames(self): self.assertRaises(TypeError, tree.assert_same_structure, NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) def testAssertSameStructure_classWrappingNamedTuple(self): self.assertRaises(TypeError, tree.assert_same_structure, NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) # pylint: enable=g-error-prone-assert-raises def testMapStructure(self): structure2 = (((7, 8), 9), 10, (11, 12)) structure1_plus1 = tree.map_structure(lambda x: x + 1, STRUCTURE1) tree.assert_same_structure(STRUCTURE1, structure1_plus1) self.assertAllEquals( [2, 3, 4, 5, 6, 7], tree.flatten(structure1_plus1)) structure1_plus_structure2 = tree.map_structure( lambda x, y: x + y, STRUCTURE1, structure2) self.assertEqual( (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), structure1_plus_structure2) self.assertEqual(3, tree.map_structure(lambda x: x - 1, 4)) self.assertEqual(7, tree.map_structure(lambda x, y: x + y, 3, 4)) # Empty structures self.assertEqual((), tree.map_structure(lambda x: x + 1, ())) self.assertEqual([], tree.map_structure(lambda x: x + 1, [])) self.assertEqual({}, tree.map_structure(lambda x: x + 1, {})) empty_nt = collections.namedtuple("empty_nt", "") self.assertEqual(empty_nt(), tree.map_structure(lambda x: x + 1, empty_nt())) # This is checking actual equality of types, empty list != empty tuple self.assertNotEqual((), tree.map_structure(lambda x: x + 1, [])) with self.assertRaisesRegex(TypeError, "callable"): tree.map_structure("bad", structure1_plus1) with self.assertRaisesRegex(ValueError, "at least one structure"): tree.map_structure(lambda x: x) with self.assertRaisesRegex(ValueError, "same number of elements"): tree.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) with self.assertRaisesRegex(ValueError, "same nested structure"): tree.map_structure(lambda x, y: None, 3, (3,)) with self.assertRaisesRegex(TypeError, "same sequence type"): tree.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) with self.assertRaisesRegex(ValueError, "same nested structure"): tree.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) structure1_list = [[[1, 2], 3], 4, [5, 6]] with self.assertRaisesRegex(TypeError, "same sequence type"): tree.map_structure(lambda x, y: None, STRUCTURE1, structure1_list) tree.map_structure(lambda x, y: None, STRUCTURE1, structure1_list, check_types=False) with self.assertRaisesRegex(ValueError, "same nested structure"): tree.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), check_types=False) with self.assertRaisesRegex(ValueError, "Only valid keyword argument.*foo"): tree.map_structure(lambda x: None, STRUCTURE1, foo="a") with self.assertRaisesRegex(ValueError, "Only valid keyword argument.*foo"): tree.map_structure(lambda x: None, STRUCTURE1, check_types=False, foo="a") def testMapStructureWithStrings(self): ab_tuple = collections.namedtuple("ab_tuple", "a, b") inp_a = ab_tuple(a="foo", b=("bar", "baz")) inp_b = ab_tuple(a=2, b=(1, 3)) out = tree.map_structure(lambda string, repeats: string * repeats, inp_a, inp_b) self.assertEqual("foofoo", out.a) self.assertEqual("bar", out.b[0]) self.assertEqual("bazbazbaz", out.b[1]) nt = ab_tuple(a=("something", "something_else"), b="yet another thing") rev_nt = tree.map_structure(lambda x: x[::-1], nt) # Check the output is the correct structure, and all strings are reversed. tree.assert_same_structure(nt, rev_nt) self.assertEqual(nt.a[0][::-1], rev_nt.a[0]) self.assertEqual(nt.a[1][::-1], rev_nt.a[1]) self.assertEqual(nt.b[::-1], rev_nt.b) def testAssertShallowStructure(self): inp_ab = ["a", "b"] inp_abc = ["a", "b", "c"] with self.assertRaisesRegex( ValueError, tree._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format( input_length=len(inp_ab), shallow_length=len(inp_abc))): tree._assert_shallow_structure(inp_abc, inp_ab) inp_ab1 = [(1, 1), (2, 2)] inp_ab2 = [[1, 1], [2, 2]] with self.assertRaisesWithLiteralMatch( TypeError, tree._STRUCTURES_HAVE_MISMATCHING_TYPES.format( shallow_type=type(inp_ab2[0]), input_type=type(inp_ab1[0]))): tree._assert_shallow_structure(shallow_tree=inp_ab2, input_tree=inp_ab1) tree._assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} with self.assertRaisesWithLiteralMatch( ValueError, tree._SHALLOW_TREE_HAS_INVALID_KEYS.format(["d"])): tree._assert_shallow_structure(inp_ab2, inp_ab1) inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) tree._assert_shallow_structure(inp_ab, inp_ba) # regression test for b/130633904 tree._assert_shallow_structure({0: "foo"}, ["foo"], check_types=False) def testFlattenUpTo(self): # Shallow tree ends at scalar. input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] shallow_tree = [[True, True], [False, True]] flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) self.assertEqual(flattened_shallow_tree, [True, True, False, True]) # Shallow tree ends at string. input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree, input_tree) input_tree_flattened = tree.flatten(input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree, [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) # Make sure dicts are correctly flattened, yielding values, not keys. input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree, input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree, [1, {"c": 2}, 3, (4, 5)]) # Namedtuples. ab_tuple = collections.namedtuple("ab_tuple", "a, b") input_tree = ab_tuple(a=[0, 1], b=2) shallow_tree = ab_tuple(a=0, b=1) input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree, input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree, [[0, 1], 2]) # Attrs. @attr.s class ABAttr(object): a = attr.ib() b = attr.ib() input_tree = ABAttr(a=[0, 1], b=2) shallow_tree = ABAttr(a=0, b=1) input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree, input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree, [[0, 1], 2]) # Nested dicts, OrderedDicts and namedtuples. input_tree = collections.OrderedDict( [("a", ab_tuple(a=[0, {"b": 1}], b=2)), ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) shallow_tree = input_tree input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree, input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree, input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree, [ab_tuple(a=[0, {"b": 1}], b=2), 3, collections.OrderedDict([("f", 4)])]) shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) input_tree_flattened_as_shallow_tree = tree.flatten_up_to(shallow_tree, input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree, [ab_tuple(a=[0, {"b": 1}], b=2), {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) ## Shallow non-list edge-case. # Using iterable elements. input_tree = ["input_tree"] shallow_tree = "shallow_tree" flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) self.assertEqual(flattened_input_tree, [input_tree]) self.assertEqual(flattened_shallow_tree, [shallow_tree]) input_tree = ["input_tree_0", "input_tree_1"] shallow_tree = "shallow_tree" flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) self.assertEqual(flattened_input_tree, [input_tree]) self.assertEqual(flattened_shallow_tree, [shallow_tree]) # Using non-iterable elements. input_tree = [0] shallow_tree = 9 flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) self.assertEqual(flattened_input_tree, [input_tree]) self.assertEqual(flattened_shallow_tree, [shallow_tree]) input_tree = [0, 1] shallow_tree = 9 flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) self.assertEqual(flattened_input_tree, [input_tree]) self.assertEqual(flattened_shallow_tree, [shallow_tree]) ## Both non-list edge-case. # Using iterable elements. input_tree = "input_tree" shallow_tree = "shallow_tree" flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) self.assertEqual(flattened_input_tree, [input_tree]) self.assertEqual(flattened_shallow_tree, [shallow_tree]) # Using non-iterable elements. input_tree = 0 shallow_tree = 0 flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) self.assertEqual(flattened_input_tree, [input_tree]) self.assertEqual(flattened_shallow_tree, [shallow_tree]) ## Input non-list edge-case. # Using iterable elements. input_tree = "input_tree" shallow_tree = ["shallow_tree"] with self.assertRaisesWithLiteralMatch( TypeError, tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))): flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) self.assertEqual(flattened_shallow_tree, shallow_tree) input_tree = "input_tree" shallow_tree = ["shallow_tree_9", "shallow_tree_8"] with self.assertRaisesWithLiteralMatch( TypeError, tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))): flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) self.assertEqual(flattened_shallow_tree, shallow_tree) # Using non-iterable elements. input_tree = 0 shallow_tree = [9] with self.assertRaisesWithLiteralMatch( TypeError, tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))): flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) self.assertEqual(flattened_shallow_tree, shallow_tree) input_tree = 0 shallow_tree = [9, 8] with self.assertRaisesWithLiteralMatch( TypeError, tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))): flattened_input_tree = tree.flatten_up_to(shallow_tree, input_tree) flattened_shallow_tree = tree.flatten_up_to(shallow_tree, shallow_tree) self.assertEqual(flattened_shallow_tree, shallow_tree) def testByteStringsNotTreatedAsIterable(self): structure = [u"unicode string", b"byte string"] flattened_structure = tree.flatten_up_to(structure, structure) self.assertEqual(structure, flattened_structure) def testFlattenWithPathUpTo(self): def get_paths_and_values(shallow_tree, input_tree): path_value_pairs = tree.flatten_with_path_up_to(shallow_tree, input_tree) paths = [p for p, _ in path_value_pairs] values = [v for _, v in path_value_pairs] return paths, values # Shallow tree ends at scalar. input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] shallow_tree = [[True, True], [False, True]] (flattened_input_tree_paths, flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) (flattened_shallow_tree_paths, flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) self.assertEqual(flattened_input_tree_paths, [(0, 0), (0, 1), (1, 0), (1, 1)]) self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) self.assertEqual(flattened_shallow_tree_paths, [(0, 0), (0, 1), (1, 0), (1, 1)]) self.assertEqual(flattened_shallow_tree, [True, True, False, True]) # Shallow tree ends at string. input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] (input_tree_flattened_as_shallow_tree_paths, input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, input_tree) input_tree_flattened_paths = [ p for p, _ in tree.flatten_with_path(input_tree) ] input_tree_flattened = tree.flatten(input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree_paths, [(0, 0), (0, 1, 0), (0, 1, 1, 0), (0, 1, 1, 1, 0)]) self.assertEqual(input_tree_flattened_as_shallow_tree, [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) self.assertEqual(input_tree_flattened_paths, [(0, 0, 0), (0, 0, 1), (0, 1, 0, 0), (0, 1, 0, 1), (0, 1, 1, 0, 0), (0, 1, 1, 0, 1), (0, 1, 1, 1, 0, 0), (0, 1, 1, 1, 0, 1)]) self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) # Make sure dicts are correctly flattened, yielding values, not keys. input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} (input_tree_flattened_as_shallow_tree_paths, input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree_paths, [("a",), ("b",), ("d", 0), ("d", 1)]) self.assertEqual(input_tree_flattened_as_shallow_tree, [1, {"c": 2}, 3, (4, 5)]) # Namedtuples. ab_tuple = collections.namedtuple("ab_tuple", "a, b") input_tree = ab_tuple(a=[0, 1], b=2) shallow_tree = ab_tuple(a=0, b=1) (input_tree_flattened_as_shallow_tree_paths, input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree_paths, [("a",), ("b",)]) self.assertEqual(input_tree_flattened_as_shallow_tree, [[0, 1], 2]) # Nested dicts, OrderedDicts and namedtuples. input_tree = collections.OrderedDict( [("a", ab_tuple(a=[0, {"b": 1}], b=2)), ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) shallow_tree = input_tree (input_tree_flattened_as_shallow_tree_paths, input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree_paths, [("a", "a", 0), ("a", "a", 1, "b"), ("a", "b"), ("c", "d"), ("c", "e", "f")]) self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) (input_tree_flattened_as_shallow_tree_paths, input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree_paths, [("a",), ("c", "d"), ("c", "e")]) self.assertEqual(input_tree_flattened_as_shallow_tree, [ab_tuple(a=[0, {"b": 1}], b=2), 3, collections.OrderedDict([("f", 4)])]) shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) (input_tree_flattened_as_shallow_tree_paths, input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree, input_tree) self.assertEqual(input_tree_flattened_as_shallow_tree_paths, [("a",), ("c",)]) self.assertEqual(input_tree_flattened_as_shallow_tree, [ab_tuple(a=[0, {"b": 1}], b=2), {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) ## Shallow non-list edge-case. # Using iterable elements. input_tree = ["input_tree"] shallow_tree = "shallow_tree" (flattened_input_tree_paths, flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) (flattened_shallow_tree_paths, flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) self.assertEqual(flattened_input_tree_paths, [()]) self.assertEqual(flattened_input_tree, [input_tree]) self.assertEqual(flattened_shallow_tree_paths, [()]) self.assertEqual(flattened_shallow_tree, [shallow_tree]) input_tree = ["input_tree_0", "input_tree_1"] shallow_tree = "shallow_tree" (flattened_input_tree_paths, flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) (flattened_shallow_tree_paths, flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) self.assertEqual(flattened_input_tree_paths, [()]) self.assertEqual(flattened_input_tree, [input_tree]) self.assertEqual(flattened_shallow_tree_paths, [()]) self.assertEqual(flattened_shallow_tree, [shallow_tree]) # Test case where len(shallow_tree) < len(input_tree) input_tree = {"a": "A", "b": "B", "c": "C"} shallow_tree = {"a": 1, "c": 2} # Using non-iterable elements. input_tree = [0] shallow_tree = 9 (flattened_input_tree_paths, flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) (flattened_shallow_tree_paths, flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) self.assertEqual(flattened_input_tree_paths, [()]) self.assertEqual(flattened_input_tree, [input_tree]) self.assertEqual(flattened_shallow_tree_paths, [()]) self.assertEqual(flattened_shallow_tree, [shallow_tree]) input_tree = [0, 1] shallow_tree = 9 (flattened_input_tree_paths, flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) (flattened_shallow_tree_paths, flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) self.assertEqual(flattened_input_tree_paths, [()]) self.assertEqual(flattened_input_tree, [input_tree]) self.assertEqual(flattened_shallow_tree_paths, [()]) self.assertEqual(flattened_shallow_tree, [shallow_tree]) ## Both non-list edge-case. # Using iterable elements. input_tree = "input_tree" shallow_tree = "shallow_tree" (flattened_input_tree_paths, flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) (flattened_shallow_tree_paths, flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) self.assertEqual(flattened_input_tree_paths, [()]) self.assertEqual(flattened_input_tree, [input_tree]) self.assertEqual(flattened_shallow_tree_paths, [()]) self.assertEqual(flattened_shallow_tree, [shallow_tree]) # Using non-iterable elements. input_tree = 0 shallow_tree = 0 (flattened_input_tree_paths, flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) (flattened_shallow_tree_paths, flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) self.assertEqual(flattened_input_tree_paths, [()]) self.assertEqual(flattened_input_tree, [input_tree]) self.assertEqual(flattened_shallow_tree_paths, [()]) self.assertEqual(flattened_shallow_tree, [shallow_tree]) ## Input non-list edge-case. # Using iterable elements. input_tree = "input_tree" shallow_tree = ["shallow_tree"] with self.assertRaisesWithLiteralMatch( TypeError, tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format( path=[], input_type=type(input_tree))): (flattened_input_tree_paths, flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) (flattened_shallow_tree_paths, flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) self.assertEqual(flattened_shallow_tree_paths, [(0,)]) self.assertEqual(flattened_shallow_tree, shallow_tree) input_tree = "input_tree" shallow_tree = ["shallow_tree_9", "shallow_tree_8"] with self.assertRaisesWithLiteralMatch( TypeError, tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format( path=[], input_type=type(input_tree))): (flattened_input_tree_paths, flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) (flattened_shallow_tree_paths, flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) self.assertEqual(flattened_shallow_tree_paths, [(0,), (1,)]) self.assertEqual(flattened_shallow_tree, shallow_tree) # Using non-iterable elements. input_tree = 0 shallow_tree = [9] with self.assertRaisesWithLiteralMatch( TypeError, tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format( path=[], input_type=type(input_tree))): (flattened_input_tree_paths, flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) (flattened_shallow_tree_paths, flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) self.assertEqual(flattened_shallow_tree_paths, [(0,)]) self.assertEqual(flattened_shallow_tree, shallow_tree) input_tree = 0 shallow_tree = [9, 8] with self.assertRaisesWithLiteralMatch( TypeError, tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format( path=[], input_type=type(input_tree))): (flattened_input_tree_paths, flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree) (flattened_shallow_tree_paths, flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree) self.assertEqual(flattened_shallow_tree_paths, [(0,), (1,)]) self.assertEqual(flattened_shallow_tree, shallow_tree) # Test that error messages include paths. input_tree = {"a": {"b": {0, 1}}} structure = {"a": {"b": [0, 1]}} with self.assertRaisesWithLiteralMatch( TypeError, tree._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ_WITH_PATH.format( path=["a", "b"], input_type=type(input_tree["a"]["b"]))): (flattened_input_tree_paths, flattened_input_tree) = get_paths_and_values(structure, input_tree) (flattened_tree_paths, flattened_tree) = get_paths_and_values(structure, structure) self.assertEqual(flattened_tree_paths, [("a", "b", 0,), ("a", "b", 1,)]) self.assertEqual(flattened_tree, structure["a"]["b"]) def testMapStructureUpTo(self): # Named tuples. ab_tuple = collections.namedtuple("ab_tuple", "a, b") op_tuple = collections.namedtuple("op_tuple", "add, mul") inp_val = ab_tuple(a=2, b=3) inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) out = tree.map_structure_up_to( inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops, check_types=False) self.assertEqual(out.a, 6) self.assertEqual(out.b, 15) # Lists. data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] name_list = ["evens", ["odds", "primes"]] out = tree.map_structure_up_to( name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), name_list, data_list) self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) # We cannot define namedtuples within @parameterized argument lists. # pylint: disable=invalid-name Foo = collections.namedtuple("Foo", ["a", "b"]) Bar = collections.namedtuple("Bar", ["c", "d"]) # pylint: enable=invalid-name @parameterized.parameters([ dict(inputs=[], expected=[]), dict(inputs=[23, "42"], expected=[((0,), 23), ((1,), "42")]), dict(inputs=[[[[108]]]], expected=[((0, 0, 0, 0), 108)]), dict(inputs=Foo(a=3, b=Bar(c=23, d=42)), expected=[(("a",), 3), (("b", "c"), 23), (("b", "d"), 42)]), dict(inputs=Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="thing")), expected=[(("a", "c"), 23), (("a", "d"), 42), (("b", "c"), 0), (("b", "d"), "thing")]), dict(inputs=Bar(c=42, d=43), expected=[(("c",), 42), (("d",), 43)]), dict(inputs=Bar(c=[42], d=43), expected=[(("c", 0), 42), (("d",), 43)]), dict(inputs=wrapt.ObjectProxy(Bar(c=[42], d=43)), expected=[(("c", 0), 42), (("d",), 43)]), ]) def testFlattenWithPath(self, inputs, expected): self.assertEqual(tree.flatten_with_path(inputs), expected) @parameterized.named_parameters([ dict(testcase_name="Tuples", s1=(1, 2), s2=(3, 4), check_types=True, expected=(((0,), 4), ((1,), 6))), dict(testcase_name="Dicts", s1={"a": 1, "b": 2}, s2={"b": 4, "a": 3}, check_types=True, expected={"a": (("a",), 4), "b": (("b",), 6)}), dict(testcase_name="Mixed", s1=(1, 2), s2=[3, 4], check_types=False, expected=(((0,), 4), ((1,), 6))), dict(testcase_name="Nested", s1={"a": [2, 3], "b": [1, 2, 3]}, s2={"b": [5, 6, 7], "a": [8, 9]}, check_types=True, expected={"a": [(("a", 0), 10), (("a", 1), 12)], "b": [(("b", 0), 6), (("b", 1), 8), (("b", 2), 10)]}), ]) def testMapWithPathCompatibleStructures(self, s1, s2, check_types, expected): def path_and_sum(path, *values): return path, sum(values) result = tree.map_structure_with_path( path_and_sum, s1, s2, check_types=check_types) self.assertEqual(expected, result) @parameterized.named_parameters([ dict(testcase_name="Tuples", s1=(1, 2, 3), s2=(4, 5), error_type=ValueError), dict(testcase_name="Dicts", s1={"a": 1}, s2={"b": 2}, error_type=ValueError), dict(testcase_name="Nested", s1={"a": [2, 3, 4], "b": [1, 3]}, s2={"b": [5, 6], "a": [8, 9]}, error_type=ValueError) ]) def testMapWithPathIncompatibleStructures(self, s1, s2, error_type): with self.assertRaises(error_type): tree.map_structure_with_path(lambda path, *s: 0, s1, s2) def testMappingProxyType(self): structure = types.MappingProxyType({"a": 1, "b": (2, 3)}) expected = types.MappingProxyType({"a": 4, "b": (5, 6)}) self.assertEqual(tree.flatten(structure), [1, 2, 3]) self.assertEqual(tree.unflatten_as(structure, [4, 5, 6]), expected) self.assertEqual(tree.map_structure(lambda v: v + 3, structure), expected) def testTraverseListsToTuples(self): structure = [(1, 2), [3], {"a": [4]}] self.assertEqual( ((1, 2), (3,), {"a": (4,)}), tree.traverse( lambda x: tuple(x) if isinstance(x, list) else x, structure, top_down=False)) def testTraverseEarlyTermination(self): structure = [(1, [2]), [3, (4, 5, 6)]] visited = [] def visit(x): visited.append(x) return "X" if isinstance(x, tuple) and len(x) > 2 else None output = tree.traverse(visit, structure) self.assertEqual([(1, [2]), [3, "X"]], output) self.assertEqual( [[(1, [2]), [3, (4, 5, 6)]], (1, [2]), 1, [2], 2, [3, (4, 5, 6)], 3, (4, 5, 6)], visited) def testMapStructureAcrossSubtreesDict(self): shallow = {"a": 1, "b": {"c": 2}} deep1 = {"a": 2, "b": {"c": 3, "d": 2}, "e": 4} deep2 = {"a": 3, "b": {"c": 2, "d": 3}, "e": 1} summed = tree.map_structure_up_to( shallow, lambda *args: sum(args), deep1, deep2) expected = {"a": 5, "b": {"c": 5}} self.assertEqual(summed, expected) concatenated = tree.map_structure_up_to( shallow, lambda *args: args, deep1, deep2) expected = {"a": (2, 3), "b": {"c": (3, 2)}} self.assertEqual(concatenated, expected) def testMapStructureAcrossSubtreesNoneValues(self): shallow = [1, [None]] deep1 = [1, [2, 3]] deep2 = [2, [3, 4]] summed = tree.map_structure_up_to( shallow, lambda *args: sum(args), deep1, deep2) expected = [3, [5]] self.assertEqual(summed, expected) def testMapStructureAcrossSubtreesList(self): shallow = [1, [1]] deep1 = [1, [2, 3]] deep2 = [2, [3, 4]] summed = tree.map_structure_up_to( shallow, lambda *args: sum(args), deep1, deep2) expected = [3, [5]] self.assertEqual(summed, expected) def testMapStructureAcrossSubtreesTuple(self): shallow = (1, (1,)) deep1 = (1, (2, 3)) deep2 = (2, (3, 4)) summed = tree.map_structure_up_to( shallow, lambda *args: sum(args), deep1, deep2) expected = (3, (5,)) self.assertEqual(summed, expected) def testMapStructureAcrossSubtreesNamedTuple(self): Foo = collections.namedtuple("Foo", ["x", "y"]) Bar = collections.namedtuple("Bar", ["x"]) shallow = Bar(1) deep1 = Foo(1, (1, 0)) deep2 = Foo(2, (2, 0)) summed = tree.map_structure_up_to( shallow, lambda *args: sum(args), deep1, deep2) expected = Bar(3) self.assertEqual(summed, expected) def testMapStructureAcrossSubtreesListTuple(self): # Tuples and lists can be used interchangeably between shallow structure # and input structures. Output takes on type of the shallow structure shallow = [1, (1,)] deep1 = [1, [2, 3]] deep2 = [2, [3, 4]] summed = tree.map_structure_up_to(shallow, lambda *args: sum(args), deep1, deep2) expected = [3, (5,)] self.assertEqual(summed, expected) shallow = [1, [1]] deep1 = [1, (2, 3)] deep2 = [2, (3, 4)] summed = tree.map_structure_up_to(shallow, lambda *args: sum(args), deep1, deep2) expected = [3, [5]] self.assertEqual(summed, expected) def testNoneNodeIncluded(self): structure = ((1, None)) self.assertEqual(tree.flatten(structure), [1, None]) def testCustomClassMapWithPath(self): class ExampleClass(Mapping[Any, Any]): """Small example custom class.""" def __init__(self, *args, **kwargs): self._mapping = dict(*args, **kwargs) def __getitem__(self, k: Any) -> Any: return self._mapping[k] def __len__(self) -> int: return len(self._mapping) def __iter__(self) -> Iterator[Any]: return iter(self._mapping) def mapper(path, value): full_path = "/".join(path) return f"{full_path}_{value}" test_input = ExampleClass({"first": 1, "nested": {"second": 2, "third": 3}}) output = tree.map_structure_with_path(mapper, test_input) expected = ExampleClass({ "first": "first_1", "nested": { "second": "nested/second_2", "third": "nested/third_3" } }) self.assertEqual(output, expected) if __name__ == "__main__": unittest.main()